forked from TrueCloudLab/lego
188 lines
5.7 KiB
Go
188 lines
5.7 KiB
Go
package azuredns
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"time"
|
|
|
|
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
|
|
"github.com/Azure/azure-sdk-for-go/sdk/azcore/arm"
|
|
"github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud"
|
|
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
|
|
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/dns/armdns"
|
|
"github.com/go-acme/lego/v4/challenge/dns01"
|
|
)
|
|
|
|
// DNSProviderPublic implements the challenge.Provider interface for Azure Public Zone DNS.
|
|
type DNSProviderPublic struct {
|
|
config *Config
|
|
credentials azcore.TokenCredential
|
|
serviceDiscoveryZones map[string]ServiceDiscoveryZone
|
|
}
|
|
|
|
// NewDNSProviderPublic creates a DNSProviderPublic structure.
|
|
func NewDNSProviderPublic(config *Config, credentials azcore.TokenCredential) (*DNSProviderPublic, error) {
|
|
zones, err := discoverDNSZones(context.Background(), config, credentials)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("discover DNS zones: %w", err)
|
|
}
|
|
|
|
return &DNSProviderPublic{
|
|
config: config,
|
|
credentials: credentials,
|
|
serviceDiscoveryZones: zones,
|
|
}, nil
|
|
}
|
|
|
|
// Timeout returns the timeout and interval to use when checking for DNS propagation.
|
|
// Adjusting here to cope with spikes in propagation times.
|
|
func (d *DNSProviderPublic) Timeout() (timeout, interval time.Duration) {
|
|
return d.config.PropagationTimeout, d.config.PollingInterval
|
|
}
|
|
|
|
// Present creates a TXT record to fulfill the dns-01 challenge.
|
|
func (d *DNSProviderPublic) Present(domain, _, keyAuth string) error {
|
|
ctx := context.Background()
|
|
info := dns01.GetChallengeInfo(domain, keyAuth)
|
|
|
|
zone, err := d.getHostedZone(info.EffectiveFQDN)
|
|
if err != nil {
|
|
return fmt.Errorf("azuredns: %w", err)
|
|
}
|
|
|
|
client, err := newPublicZoneClient(zone, d.credentials, d.config.Environment)
|
|
if err != nil {
|
|
return fmt.Errorf("azuredns: %w", err)
|
|
}
|
|
|
|
subDomain, err := dns01.ExtractSubDomain(info.EffectiveFQDN, zone.Name)
|
|
if err != nil {
|
|
return fmt.Errorf("azuredns: %w", err)
|
|
}
|
|
|
|
// Get existing record set
|
|
resp, err := client.Get(ctx, subDomain)
|
|
if err != nil {
|
|
var respErr *azcore.ResponseError
|
|
if !errors.As(err, &respErr) || respErr.StatusCode != http.StatusNotFound {
|
|
return fmt.Errorf("azuredns: %w", err)
|
|
}
|
|
}
|
|
|
|
uniqRecords := publicUniqueRecords(resp.RecordSet, info.Value)
|
|
|
|
var txtRecords []*armdns.TxtRecord
|
|
for txt := range uniqRecords {
|
|
txtRecords = append(txtRecords, &armdns.TxtRecord{Value: to.SliceOfPtrs(txt)})
|
|
}
|
|
|
|
rec := armdns.RecordSet{
|
|
Name: &subDomain,
|
|
Properties: &armdns.RecordSetProperties{
|
|
TTL: to.Ptr(int64(d.config.TTL)),
|
|
TxtRecords: txtRecords,
|
|
},
|
|
}
|
|
|
|
_, err = client.CreateOrUpdate(ctx, subDomain, rec)
|
|
if err != nil {
|
|
return fmt.Errorf("azuredns: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// CleanUp removes the TXT record matching the specified parameters.
|
|
func (d *DNSProviderPublic) CleanUp(domain, _, keyAuth string) error {
|
|
ctx := context.Background()
|
|
info := dns01.GetChallengeInfo(domain, keyAuth)
|
|
|
|
zone, err := d.getHostedZone(info.EffectiveFQDN)
|
|
if err != nil {
|
|
return fmt.Errorf("azuredns: %w", err)
|
|
}
|
|
|
|
client, err := newPublicZoneClient(zone, d.credentials, d.config.Environment)
|
|
if err != nil {
|
|
return fmt.Errorf("azuredns: %w", err)
|
|
}
|
|
|
|
subDomain, err := dns01.ExtractSubDomain(info.EffectiveFQDN, zone.Name)
|
|
if err != nil {
|
|
return fmt.Errorf("azuredns: %w", err)
|
|
}
|
|
|
|
_, err = client.Delete(ctx, subDomain)
|
|
if err != nil {
|
|
return fmt.Errorf("azuredns: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Checks that azure has a zone for this domain name.
|
|
func (d *DNSProviderPublic) getHostedZone(fqdn string) (ServiceDiscoveryZone, error) {
|
|
authZone, err := getAuthZone(fqdn)
|
|
if err != nil {
|
|
return ServiceDiscoveryZone{}, err
|
|
}
|
|
|
|
azureZone, exists := d.serviceDiscoveryZones[dns01.UnFqdn(authZone)]
|
|
if !exists {
|
|
return ServiceDiscoveryZone{}, fmt.Errorf("could not find zone (from discovery): %s", authZone)
|
|
}
|
|
|
|
return azureZone, nil
|
|
}
|
|
|
|
type publicZoneClient struct {
|
|
zone ServiceDiscoveryZone
|
|
recordClient *armdns.RecordSetsClient
|
|
}
|
|
|
|
// newPublicZoneClient creates publicZoneClient structure with initialized Azure client.
|
|
func newPublicZoneClient(zone ServiceDiscoveryZone, credential azcore.TokenCredential, environment cloud.Configuration) (*publicZoneClient, error) {
|
|
options := &arm.ClientOptions{
|
|
ClientOptions: azcore.ClientOptions{
|
|
Cloud: environment,
|
|
},
|
|
}
|
|
|
|
recordClient, err := armdns.NewRecordSetsClient(zone.SubscriptionID, credential, options)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &publicZoneClient{
|
|
zone: zone,
|
|
recordClient: recordClient,
|
|
}, nil
|
|
}
|
|
|
|
func (c publicZoneClient) Get(ctx context.Context, subDomain string) (armdns.RecordSetsClientGetResponse, error) {
|
|
return c.recordClient.Get(ctx, c.zone.ResourceGroup, c.zone.Name, subDomain, armdns.RecordTypeTXT, nil)
|
|
}
|
|
|
|
func (c publicZoneClient) CreateOrUpdate(ctx context.Context, subDomain string, rec armdns.RecordSet) (armdns.RecordSetsClientCreateOrUpdateResponse, error) {
|
|
return c.recordClient.CreateOrUpdate(ctx, c.zone.ResourceGroup, c.zone.Name, subDomain, armdns.RecordTypeTXT, rec, nil)
|
|
}
|
|
|
|
func (c publicZoneClient) Delete(ctx context.Context, subDomain string) (armdns.RecordSetsClientDeleteResponse, error) {
|
|
return c.recordClient.Delete(ctx, c.zone.ResourceGroup, c.zone.Name, subDomain, armdns.RecordTypeTXT, nil)
|
|
}
|
|
|
|
func publicUniqueRecords(recordSet armdns.RecordSet, value string) map[string]struct{} {
|
|
uniqRecords := map[string]struct{}{value: {}}
|
|
if recordSet.Properties != nil && recordSet.Properties.TxtRecords != nil {
|
|
for _, txtRecord := range recordSet.Properties.TxtRecords {
|
|
// Assume Value doesn't contain multiple strings
|
|
if len(txtRecord.Value) > 0 {
|
|
uniqRecords[deref(txtRecord.Value[0])] = struct{}{}
|
|
}
|
|
}
|
|
}
|
|
|
|
return uniqRecords
|
|
}
|