diff --git a/cmd/zz_gen_cmd_dnshelp.go b/cmd/zz_gen_cmd_dnshelp.go index 7c9309c3..dbb610c7 100644 --- a/cmd/zz_gen_cmd_dnshelp.go +++ b/cmd/zz_gen_cmd_dnshelp.go @@ -265,6 +265,7 @@ func displayDNSHelp(name string) error { ew.writeln(`Additional Configuration:`) ew.writeln(` - "AZURE_METADATA_ENDPOINT": Metadata Service endpoint URL`) ew.writeln(` - "AZURE_POLLING_INTERVAL": Time between DNS propagation check`) + ew.writeln(` - "AZURE_PRIVATE_ZONE": Set to true to use Azure Private DNS Zones and not public`) ew.writeln(` - "AZURE_PROPAGATION_TIMEOUT": Maximum waiting time for DNS propagation`) ew.writeln(` - "AZURE_TTL": The TTL of the TXT record used for the DNS challenge`) ew.writeln(` - "AZURE_ZONE_NAME": Zone name to use inside Azure DNS service to add the TXT record in`) diff --git a/docs/content/dns/zz_gen_azure.md b/docs/content/dns/zz_gen_azure.md index be0cd674..51c728cb 100644 --- a/docs/content/dns/zz_gen_azure.md +++ b/docs/content/dns/zz_gen_azure.md @@ -47,6 +47,7 @@ More information [here](/lego/dns/#configuration-and-credentials). |--------------------------------|-------------| | `AZURE_METADATA_ENDPOINT` | Metadata Service endpoint URL | | `AZURE_POLLING_INTERVAL` | Time between DNS propagation check | +| `AZURE_PRIVATE_ZONE` | Set to true to use Azure Private DNS Zones and not public | | `AZURE_PROPAGATION_TIMEOUT` | Maximum waiting time for DNS propagation | | `AZURE_TTL` | The TTL of the TXT record used for the DNS challenge | | `AZURE_ZONE_NAME` | Zone name to use inside Azure DNS service to add the TXT record in | diff --git a/providers/dns/azure/azure.go b/providers/dns/azure/azure.go index cf47a518..33c862a2 100644 --- a/providers/dns/azure/azure.go +++ b/providers/dns/azure/azure.go @@ -3,7 +3,6 @@ package azure import ( - "context" "errors" "fmt" "io" @@ -11,11 +10,10 @@ import ( "strings" "time" - "github.com/Azure/azure-sdk-for-go/services/dns/mgmt/2017-09-01/dns" "github.com/Azure/go-autorest/autorest" aazure "github.com/Azure/go-autorest/autorest/azure" "github.com/Azure/go-autorest/autorest/azure/auth" - "github.com/Azure/go-autorest/autorest/to" + "github.com/go-acme/lego/v4/challenge" "github.com/go-acme/lego/v4/challenge/dns01" "github.com/go-acme/lego/v4/platform/config/env" ) @@ -34,6 +32,7 @@ const ( EnvClientID = envNamespace + "CLIENT_ID" EnvClientSecret = envNamespace + "CLIENT_SECRET" EnvZoneName = envNamespace + "ZONE_NAME" + EnvPrivateZone = envNamespace + "PRIVATE_ZONE" EnvTTL = envNamespace + "TTL" EnvPropagationTimeout = envNamespace + "PROPAGATION_TIMEOUT" @@ -49,6 +48,7 @@ type Config struct { SubscriptionID string ResourceGroup string + PrivateZone bool MetadataEndpoint string ResourceManagerEndpoint string @@ -74,8 +74,7 @@ func NewDefaultConfig() *Config { // DNSProvider implements the challenge.Provider interface. type DNSProvider struct { - config *Config - authorizer autorest.Authorizer + provider challenge.ProviderTimeout } // NewDNSProvider returns a DNSProvider instance configured for azure. @@ -113,6 +112,7 @@ func NewDNSProvider() (*DNSProvider, error) { config.ClientSecret = env.GetOrFile(EnvClientSecret) config.ClientID = env.GetOrFile(EnvClientID) config.TenantID = env.GetOrFile(EnvTenantID) + config.PrivateZone = env.GetOrDefaultBool(EnvPrivateZone, false) return NewDNSProviderConfig(config) } @@ -156,112 +156,27 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) { config.ResourceGroup = resGroup } - return &DNSProvider{config: config, authorizer: authorizer}, nil + if config.PrivateZone { + return &DNSProvider{provider: &dnsProviderPrivate{config: config, authorizer: authorizer}}, nil + } + + return &DNSProvider{provider: &dnsProviderPublic{config: config, authorizer: authorizer}}, nil } // Timeout returns the timeout and interval to use when checking for DNS propagation. // Adjusting here to cope with spikes in propagation times. func (d *DNSProvider) Timeout() (timeout, interval time.Duration) { - return d.config.PropagationTimeout, d.config.PollingInterval + return d.provider.Timeout() } // Present creates a TXT record to fulfill the dns-01 challenge. func (d *DNSProvider) Present(domain, token, keyAuth string) error { - ctx := context.Background() - fqdn, value := dns01.GetRecord(domain, keyAuth) - - zone, err := d.getHostedZoneID(ctx, fqdn) - if err != nil { - return fmt.Errorf("azure: %w", err) - } - - rsc := dns.NewRecordSetsClientWithBaseURI(d.config.ResourceManagerEndpoint, d.config.SubscriptionID) - rsc.Authorizer = d.authorizer - - relative := toRelativeRecord(fqdn, dns01.ToFqdn(zone)) - - // Get existing record set - rset, err := rsc.Get(ctx, d.config.ResourceGroup, zone, relative, dns.TXT) - if err != nil { - var detailed autorest.DetailedError - if !errors.As(err, &detailed) || detailed.StatusCode != http.StatusNotFound { - return fmt.Errorf("azure: %w", err) - } - } - - // Construct unique TXT records using map - uniqRecords := map[string]struct{}{value: {}} - if rset.RecordSetProperties != nil && rset.TxtRecords != nil { - for _, txtRecord := range *rset.TxtRecords { - // Assume Value doesn't contain multiple strings - if txtRecord.Value != nil && len(*txtRecord.Value) > 0 { - uniqRecords[(*txtRecord.Value)[0]] = struct{}{} - } - } - } - - var txtRecords []dns.TxtRecord - for txt := range uniqRecords { - txtRecords = append(txtRecords, dns.TxtRecord{Value: &[]string{txt}}) - } - - rec := dns.RecordSet{ - Name: &relative, - RecordSetProperties: &dns.RecordSetProperties{ - TTL: to.Int64Ptr(int64(d.config.TTL)), - TxtRecords: &txtRecords, - }, - } - - _, err = rsc.CreateOrUpdate(ctx, d.config.ResourceGroup, zone, relative, dns.TXT, rec, "", "") - if err != nil { - return fmt.Errorf("azure: %w", err) - } - return nil + return d.provider.Present(domain, token, keyAuth) } // CleanUp removes the TXT record matching the specified parameters. func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error { - ctx := context.Background() - fqdn, _ := dns01.GetRecord(domain, keyAuth) - - zone, err := d.getHostedZoneID(ctx, fqdn) - if err != nil { - return fmt.Errorf("azure: %w", err) - } - - relative := toRelativeRecord(fqdn, dns01.ToFqdn(zone)) - rsc := dns.NewRecordSetsClientWithBaseURI(d.config.ResourceManagerEndpoint, d.config.SubscriptionID) - rsc.Authorizer = d.authorizer - - _, err = rsc.Delete(ctx, d.config.ResourceGroup, zone, relative, dns.TXT, "") - if err != nil { - return fmt.Errorf("azure: %w", err) - } - return nil -} - -// Checks that azure has a zone for this domain name. -func (d *DNSProvider) getHostedZoneID(ctx context.Context, fqdn string) (string, error) { - if zone := env.GetOrFile(EnvZoneName); zone != "" { - return zone, nil - } - - authZone, err := dns01.FindZoneByFqdn(fqdn) - if err != nil { - return "", err - } - - dc := dns.NewZonesClientWithBaseURI(d.config.ResourceManagerEndpoint, d.config.SubscriptionID) - dc.Authorizer = d.authorizer - - zone, err := dc.Get(ctx, d.config.ResourceGroup, dns01.UnFqdn(authZone)) - if err != nil { - return "", err - } - - // zone.Name shouldn't have a trailing dot(.) - return to.String(zone.Name), nil + return d.provider.CleanUp(domain, token, keyAuth) } // Returns the relative record to the domain. diff --git a/providers/dns/azure/azure.toml b/providers/dns/azure/azure.toml index ae5ef422..164512e9 100644 --- a/providers/dns/azure/azure.toml +++ b/providers/dns/azure/azure.toml @@ -16,11 +16,12 @@ Example = '''''' AZURE_RESOURCE_GROUP = "Resource group" 'instance metadata service' = "If the credentials are **not** set via the environment, then it will attempt to get a bearer token via the [instance metadata service](https://docs.microsoft.com/en-us/azure/virtual-machines/windows/instance-metadata-service)." [Configuration.Additional] + AZURE_METADATA_ENDPOINT = "Metadata Service endpoint URL" + AZURE_PRIVATE_ZONE = "Set to true to use Azure Private DNS Zones and not public" + AZURE_ZONE_NAME = "Zone name to use inside Azure DNS service to add the TXT record in" AZURE_POLLING_INTERVAL = "Time between DNS propagation check" AZURE_PROPAGATION_TIMEOUT = "Maximum waiting time for DNS propagation" AZURE_TTL = "The TTL of the TXT record used for the DNS challenge" - AZURE_METADATA_ENDPOINT = "Metadata Service endpoint URL" - AZURE_ZONE_NAME = "Zone name to use inside Azure DNS service to add the TXT record in" [Links] API = "https://docs.microsoft.com/en-us/go/azure/" diff --git a/providers/dns/azure/azure_test.go b/providers/dns/azure/azure_test.go index 2fa04d99..49616836 100644 --- a/providers/dns/azure/azure_test.go +++ b/providers/dns/azure/azure_test.go @@ -7,6 +7,7 @@ import ( "time" "github.com/go-acme/lego/v4/platform/tester" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -59,13 +60,16 @@ func TestNewDNSProvider(t *testing.T) { p, err := NewDNSProvider() - if test.expected == "" { - require.NoError(t, err) - require.NotNil(t, p) - require.NotNil(t, p.config) - } else { + if test.expected != "" { require.EqualError(t, err, test.expected) + return } + + require.NoError(t, err) + require.NotNil(t, p) + require.NotNil(t, p.provider) + + assert.IsType(t, p.provider, new(dnsProviderPublic)) }) } } @@ -78,16 +82,27 @@ func TestNewDNSProviderConfig(t *testing.T) { subscriptionID string tenantID string resourceGroup string + privateZone bool handler func(w http.ResponseWriter, r *http.Request) expected string }{ { - desc: "success", + desc: "success (public)", clientID: "A", clientSecret: "B", tenantID: "C", subscriptionID: "D", resourceGroup: "E", + privateZone: false, + }, + { + desc: "success (private)", + clientID: "A", + clientSecret: "B", + tenantID: "C", + subscriptionID: "D", + resourceGroup: "E", + privateZone: true, }, { desc: "SubscriptionID missing", @@ -132,6 +147,7 @@ func TestNewDNSProviderConfig(t *testing.T) { config.SubscriptionID = test.subscriptionID config.TenantID = test.tenantID config.ResourceGroup = test.resourceGroup + config.PrivateZone = test.privateZone mux := http.NewServeMux() server := httptest.NewServer(mux) @@ -146,12 +162,19 @@ func TestNewDNSProviderConfig(t *testing.T) { p, err := NewDNSProviderConfig(config) - if test.expected == "" { - require.NoError(t, err) - require.NotNil(t, p) - require.NotNil(t, p.config) - } else { + if test.expected != "" { require.EqualError(t, err, test.expected) + return + } + + require.NoError(t, err) + require.NotNil(t, p) + require.NotNil(t, p.provider) + + if test.privateZone { + assert.IsType(t, p.provider, new(dnsProviderPrivate)) + } else { + assert.IsType(t, p.provider, new(dnsProviderPublic)) } }) } diff --git a/providers/dns/azure/private.go b/providers/dns/azure/private.go new file mode 100644 index 00000000..3da08c8a --- /dev/null +++ b/providers/dns/azure/private.go @@ -0,0 +1,128 @@ +package azure + +import ( + "context" + "errors" + "fmt" + "net/http" + "time" + + "github.com/Azure/azure-sdk-for-go/services/privatedns/mgmt/2018-09-01/privatedns" + "github.com/Azure/go-autorest/autorest" + "github.com/Azure/go-autorest/autorest/to" + "github.com/go-acme/lego/v4/challenge/dns01" + "github.com/go-acme/lego/v4/platform/config/env" +) + +// dnsProviderPrivate implements the challenge.Provider interface for Azure Private Zone DNS. +type dnsProviderPrivate struct { + config *Config + authorizer autorest.Authorizer +} + +// Timeout returns the timeout and interval to use when checking for DNS propagation. +// Adjusting here to cope with spikes in propagation times. +func (d *dnsProviderPrivate) Timeout() (timeout, interval time.Duration) { + return d.config.PropagationTimeout, d.config.PollingInterval +} + +// Present creates a TXT record to fulfill the dns-01 challenge. +func (d *dnsProviderPrivate) Present(domain, token, keyAuth string) error { + ctx := context.Background() + fqdn, value := dns01.GetRecord(domain, keyAuth) + + zone, err := d.getHostedZoneID(ctx, fqdn) + if err != nil { + return fmt.Errorf("azure: %w", err) + } + + rsc := privatedns.NewRecordSetsClientWithBaseURI(d.config.ResourceManagerEndpoint, d.config.SubscriptionID) + rsc.Authorizer = d.authorizer + + relative := toRelativeRecord(fqdn, dns01.ToFqdn(zone)) + + // Get existing record set + rset, err := rsc.Get(ctx, d.config.ResourceGroup, zone, privatedns.TXT, relative) + if err != nil { + var detailed autorest.DetailedError + if !errors.As(err, &detailed) || detailed.StatusCode != http.StatusNotFound { + return fmt.Errorf("azure: %w", err) + } + } + + // Construct unique TXT records using map + uniqRecords := map[string]struct{}{value: {}} + if rset.RecordSetProperties != nil && rset.TxtRecords != nil { + for _, txtRecord := range *rset.TxtRecords { + // Assume Value doesn't contain multiple strings + values := to.StringSlice(txtRecord.Value) + if len(values) > 0 { + uniqRecords[values[0]] = struct{}{} + } + } + } + + var txtRecords []privatedns.TxtRecord + for txt := range uniqRecords { + txtRecords = append(txtRecords, privatedns.TxtRecord{Value: &[]string{txt}}) + } + + rec := privatedns.RecordSet{ + Name: &relative, + RecordSetProperties: &privatedns.RecordSetProperties{ + TTL: to.Int64Ptr(int64(d.config.TTL)), + TxtRecords: &txtRecords, + }, + } + + _, err = rsc.CreateOrUpdate(ctx, d.config.ResourceGroup, zone, privatedns.TXT, relative, rec, "", "") + if err != nil { + return fmt.Errorf("azure: %w", err) + } + return nil +} + +// CleanUp removes the TXT record matching the specified parameters. +func (d *dnsProviderPrivate) CleanUp(domain, token, keyAuth string) error { + ctx := context.Background() + fqdn, _ := dns01.GetRecord(domain, keyAuth) + + zone, err := d.getHostedZoneID(ctx, fqdn) + if err != nil { + return fmt.Errorf("azure: %w", err) + } + + relative := toRelativeRecord(fqdn, dns01.ToFqdn(zone)) + + rsc := privatedns.NewRecordSetsClientWithBaseURI(d.config.ResourceManagerEndpoint, d.config.SubscriptionID) + rsc.Authorizer = d.authorizer + + _, err = rsc.Delete(ctx, d.config.ResourceGroup, zone, privatedns.TXT, relative, "") + if err != nil { + return fmt.Errorf("azure: %w", err) + } + return nil +} + +// Checks that azure has a zone for this domain name. +func (d *dnsProviderPrivate) getHostedZoneID(ctx context.Context, fqdn string) (string, error) { + if zone := env.GetOrFile(EnvZoneName); zone != "" { + return zone, nil + } + + authZone, err := dns01.FindZoneByFqdn(fqdn) + if err != nil { + return "", err + } + + dc := privatedns.NewPrivateZonesClientWithBaseURI(d.config.ResourceManagerEndpoint, d.config.SubscriptionID) + dc.Authorizer = d.authorizer + + zone, err := dc.Get(ctx, d.config.ResourceGroup, dns01.UnFqdn(authZone)) + if err != nil { + return "", err + } + + // zone.Name shouldn't have a trailing dot(.) + return to.String(zone.Name), nil +} diff --git a/providers/dns/azure/public.go b/providers/dns/azure/public.go new file mode 100644 index 00000000..26a5efe8 --- /dev/null +++ b/providers/dns/azure/public.go @@ -0,0 +1,128 @@ +package azure + +import ( + "context" + "errors" + "fmt" + "net/http" + "time" + + "github.com/Azure/azure-sdk-for-go/services/dns/mgmt/2017-09-01/dns" + "github.com/Azure/go-autorest/autorest" + "github.com/Azure/go-autorest/autorest/to" + "github.com/go-acme/lego/v4/challenge/dns01" + "github.com/go-acme/lego/v4/platform/config/env" +) + +// dnsProviderPublic implements the challenge.Provider interface for Azure Public Zone DNS. +type dnsProviderPublic struct { + config *Config + authorizer autorest.Authorizer +} + +// Timeout returns the timeout and interval to use when checking for DNS propagation. +// Adjusting here to cope with spikes in propagation times. +func (d *dnsProviderPublic) Timeout() (timeout, interval time.Duration) { + return d.config.PropagationTimeout, d.config.PollingInterval +} + +// Present creates a TXT record to fulfill the dns-01 challenge. +func (d *dnsProviderPublic) Present(domain, token, keyAuth string) error { + ctx := context.Background() + fqdn, value := dns01.GetRecord(domain, keyAuth) + + zone, err := d.getHostedZoneID(ctx, fqdn) + if err != nil { + return fmt.Errorf("azure: %w", err) + } + + rsc := dns.NewRecordSetsClientWithBaseURI(d.config.ResourceManagerEndpoint, d.config.SubscriptionID) + rsc.Authorizer = d.authorizer + + relative := toRelativeRecord(fqdn, dns01.ToFqdn(zone)) + + // Get existing record set + rset, err := rsc.Get(ctx, d.config.ResourceGroup, zone, relative, dns.TXT) + if err != nil { + var detailed autorest.DetailedError + if !errors.As(err, &detailed) || detailed.StatusCode != http.StatusNotFound { + return fmt.Errorf("azure: %w", err) + } + } + + // Construct unique TXT records using map + uniqRecords := map[string]struct{}{value: {}} + if rset.RecordSetProperties != nil && rset.TxtRecords != nil { + for _, txtRecord := range *rset.TxtRecords { + // Assume Value doesn't contain multiple strings + values := to.StringSlice(txtRecord.Value) + if len(values) > 0 { + uniqRecords[values[0]] = struct{}{} + } + } + } + + var txtRecords []dns.TxtRecord + for txt := range uniqRecords { + txtRecords = append(txtRecords, dns.TxtRecord{Value: &[]string{txt}}) + } + + rec := dns.RecordSet{ + Name: &relative, + RecordSetProperties: &dns.RecordSetProperties{ + TTL: to.Int64Ptr(int64(d.config.TTL)), + TxtRecords: &txtRecords, + }, + } + + _, err = rsc.CreateOrUpdate(ctx, d.config.ResourceGroup, zone, relative, dns.TXT, rec, "", "") + if err != nil { + return fmt.Errorf("azure: %w", err) + } + return nil +} + +// CleanUp removes the TXT record matching the specified parameters. +func (d *dnsProviderPublic) CleanUp(domain, token, keyAuth string) error { + ctx := context.Background() + fqdn, _ := dns01.GetRecord(domain, keyAuth) + + zone, err := d.getHostedZoneID(ctx, fqdn) + if err != nil { + return fmt.Errorf("azure: %w", err) + } + + relative := toRelativeRecord(fqdn, dns01.ToFqdn(zone)) + + rsc := dns.NewRecordSetsClientWithBaseURI(d.config.ResourceManagerEndpoint, d.config.SubscriptionID) + rsc.Authorizer = d.authorizer + + _, err = rsc.Delete(ctx, d.config.ResourceGroup, zone, relative, dns.TXT, "") + if err != nil { + return fmt.Errorf("azure: %w", err) + } + return nil +} + +// Checks that azure has a zone for this domain name. +func (d *dnsProviderPublic) getHostedZoneID(ctx context.Context, fqdn string) (string, error) { + if zone := env.GetOrFile(EnvZoneName); zone != "" { + return zone, nil + } + + authZone, err := dns01.FindZoneByFqdn(fqdn) + if err != nil { + return "", err + } + + dc := dns.NewZonesClientWithBaseURI(d.config.ResourceManagerEndpoint, d.config.SubscriptionID) + dc.Authorizer = d.authorizer + + zone, err := dc.Get(ctx, d.config.ResourceGroup, dns01.UnFqdn(authZone)) + if err != nil { + return "", err + } + + // zone.Name shouldn't have a trailing dot(.) + return to.String(zone.Name), nil +}