forked from TrueCloudLab/lego
Merge pull request #186 from LukeHandle/patch-dns-retryquery
Retry logic for dnsQuery
This commit is contained in:
commit
cbca761215
9 changed files with 40 additions and 28 deletions
|
@ -21,7 +21,10 @@ var (
|
|||
fqdnToZone = map[string]string{}
|
||||
)
|
||||
|
||||
var RecursiveNameserver = "google-public-dns-a.google.com:53"
|
||||
var RecursiveNameservers = []string{
|
||||
"google-public-dns-a.google.com:53",
|
||||
"google-public-dns-b.google.com:53",
|
||||
}
|
||||
|
||||
// DNS01Record returns a DNS record which will fulfill the `dns-01` challenge
|
||||
func DNS01Record(domain, keyAuth string) (fqdn string, value string, ttl int) {
|
||||
|
@ -56,12 +59,12 @@ func (s *dnsChallenge) Solve(chlng challenge, domain string) error {
|
|||
|
||||
err = s.provider.Present(domain, chlng.Token, keyAuth)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error presenting token %s", err)
|
||||
return fmt.Errorf("Error presenting token: %s", err)
|
||||
}
|
||||
defer func() {
|
||||
err := s.provider.CleanUp(domain, chlng.Token, keyAuth)
|
||||
if err != nil {
|
||||
log.Printf("Error cleaning up %s %v ", domain, err)
|
||||
log.Printf("Error cleaning up %s: %v ", domain, err)
|
||||
}
|
||||
}()
|
||||
|
||||
|
@ -90,7 +93,7 @@ func (s *dnsChallenge) Solve(chlng challenge, domain string) error {
|
|||
// checkDNSPropagation checks if the expected TXT record has been propagated to all authoritative nameservers.
|
||||
func checkDNSPropagation(fqdn, value string) (bool, error) {
|
||||
// Initial attempt to resolve at the recursive NS
|
||||
r, err := dnsQuery(fqdn, dns.TypeTXT, RecursiveNameserver, true)
|
||||
r, err := dnsQuery(fqdn, dns.TypeTXT, RecursiveNameservers, true)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
@ -117,7 +120,7 @@ func checkDNSPropagation(fqdn, value string) (bool, error) {
|
|||
// checkAuthoritativeNss queries each of the given nameservers for the expected TXT record.
|
||||
func checkAuthoritativeNss(fqdn, value string, nameservers []string) (bool, error) {
|
||||
for _, ns := range nameservers {
|
||||
r, err := dnsQuery(fqdn, dns.TypeTXT, net.JoinHostPort(ns, "53"), false)
|
||||
r, err := dnsQuery(fqdn, dns.TypeTXT, []string{net.JoinHostPort(ns, "53")}, false)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
@ -144,9 +147,9 @@ func checkAuthoritativeNss(fqdn, value string, nameservers []string) (bool, erro
|
|||
return true, nil
|
||||
}
|
||||
|
||||
// dnsQuery sends a DNS query to the given nameserver.
|
||||
// dnsQuery will query a nameserver, iterating through the supplied servers as it retries
|
||||
// The nameserver should include a port, to facilitate testing where we talk to a mock dns server.
|
||||
func dnsQuery(fqdn string, rtype uint16, nameserver string, recursive bool) (in *dns.Msg, err error) {
|
||||
func dnsQuery(fqdn string, rtype uint16, nameservers []string, recursive bool) (in *dns.Msg, err error) {
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion(fqdn, rtype)
|
||||
m.SetEdns0(4096, false)
|
||||
|
@ -155,12 +158,21 @@ func dnsQuery(fqdn string, rtype uint16, nameserver string, recursive bool) (in
|
|||
m.RecursionDesired = false
|
||||
}
|
||||
|
||||
in, err = dns.Exchange(m, nameserver)
|
||||
if err == dns.ErrTruncated {
|
||||
tcp := &dns.Client{Net: "tcp"}
|
||||
in, _, err = tcp.Exchange(m, nameserver)
|
||||
}
|
||||
// Will retry the request based on the number of servers (n+1)
|
||||
for i := 1; i <= len(nameservers)+1; i++ {
|
||||
ns := nameservers[i%len(nameservers)]
|
||||
in, err = dns.Exchange(m, ns)
|
||||
|
||||
if err == dns.ErrTruncated {
|
||||
tcp := &dns.Client{Net: "tcp"}
|
||||
// If the TCP request suceeds, the err will reset to nil
|
||||
in, _, err = tcp.Exchange(m, ns)
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -168,12 +180,12 @@ func dnsQuery(fqdn string, rtype uint16, nameserver string, recursive bool) (in
|
|||
func lookupNameservers(fqdn string) ([]string, error) {
|
||||
var authoritativeNss []string
|
||||
|
||||
zone, err := FindZoneByFqdn(fqdn, RecursiveNameserver)
|
||||
zone, err := FindZoneByFqdn(fqdn, RecursiveNameservers)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r, err := dnsQuery(zone, dns.TypeNS, RecursiveNameserver, true)
|
||||
r, err := dnsQuery(zone, dns.TypeNS, RecursiveNameservers, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -191,7 +203,7 @@ func lookupNameservers(fqdn string) ([]string, error) {
|
|||
}
|
||||
|
||||
// FindZoneByFqdn determines the zone of the given fqdn
|
||||
func FindZoneByFqdn(fqdn, nameserver string) (string, error) {
|
||||
func FindZoneByFqdn(fqdn string, nameservers []string) (string, error) {
|
||||
// Do we have it cached?
|
||||
if zone, ok := fqdnToZone[fqdn]; ok {
|
||||
return zone, nil
|
||||
|
@ -203,13 +215,13 @@ func FindZoneByFqdn(fqdn, nameserver string) (string, error) {
|
|||
// Name servers authoritative for a zone MUST include the SOA record of
|
||||
// the zone in the authority section of the response when reporting an
|
||||
// NXDOMAIN or indicating that no data (NODATA) of the requested type exists
|
||||
in, err := dnsQuery(fqdn, dns.TypeSOA, nameserver, true)
|
||||
in, err := dnsQuery(fqdn, dns.TypeSOA, nameservers, true)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if in.Rcode != dns.RcodeNameError {
|
||||
if in.Rcode != dns.RcodeSuccess {
|
||||
return "", fmt.Errorf("NS %s returned %s for %s", nameserver, dns.RcodeToString[in.Rcode], fqdn)
|
||||
return "", fmt.Errorf("The NS returned %s for %s", dns.RcodeToString[in.Rcode], fqdn)
|
||||
}
|
||||
// We have a success, so one of the answers has to be a SOA RR
|
||||
for _, ans := range in.Answer {
|
||||
|
@ -225,7 +237,7 @@ func FindZoneByFqdn(fqdn, nameserver string) (string, error) {
|
|||
return checkIfTLD(fqdn, soa)
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("NS %s did not return the expected SOA record in the authority section", nameserver)
|
||||
return "", fmt.Errorf("The NS did not return the expected SOA record in the authority section")
|
||||
}
|
||||
|
||||
func checkIfTLD(fqdn string, soa *dns.SOA) (string, error) {
|
||||
|
|
|
@ -23,7 +23,7 @@ func (*DNSProviderManual) Present(domain, token, keyAuth string) error {
|
|||
fqdn, value, ttl := DNS01Record(domain, keyAuth)
|
||||
dnsRecord := fmt.Sprintf(dnsTemplate, fqdn, ttl, value)
|
||||
|
||||
authZone, err := FindZoneByFqdn(fqdn, RecursiveNameserver)
|
||||
authZone, err := FindZoneByFqdn(fqdn, RecursiveNameservers)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -42,7 +42,7 @@ func (*DNSProviderManual) CleanUp(domain, token, keyAuth string) error {
|
|||
fqdn, _, ttl := DNS01Record(domain, keyAuth)
|
||||
dnsRecord := fmt.Sprintf(dnsTemplate, fqdn, ttl, "...")
|
||||
|
||||
authZone, err := FindZoneByFqdn(fqdn, RecursiveNameserver)
|
||||
authZone, err := FindZoneByFqdn(fqdn, RecursiveNameservers)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -104,7 +104,7 @@ func (c *DNSProvider) getHostedZoneID(fqdn string) (string, error) {
|
|||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
authZone, err := acme.FindZoneByFqdn(fqdn, acme.RecursiveNameserver)
|
||||
authZone, err := acme.FindZoneByFqdn(fqdn, acme.RecursiveNameservers)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
|
|
@ -63,7 +63,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error {
|
|||
|
||||
fqdn, value, _ := acme.DNS01Record(domain, keyAuth)
|
||||
|
||||
authZone, err := acme.FindZoneByFqdn(acme.ToFqdn(domain), acme.RecursiveNameserver)
|
||||
authZone, err := acme.FindZoneByFqdn(acme.ToFqdn(domain), acme.RecursiveNameservers)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Could not determine zone for domain: '%s'. %s", domain, err)
|
||||
}
|
||||
|
@ -122,7 +122,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
|
|||
return fmt.Errorf("unknown record ID for '%s'", fqdn)
|
||||
}
|
||||
|
||||
authZone, err := acme.FindZoneByFqdn(acme.ToFqdn(domain), acme.RecursiveNameserver)
|
||||
authZone, err := acme.FindZoneByFqdn(acme.ToFqdn(domain), acme.RecursiveNameservers)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Could not determine zone for domain: '%s'. %s", domain, err)
|
||||
}
|
||||
|
|
|
@ -79,7 +79,7 @@ func (c *DNSProvider) getHostedZone(domain string) (string, string, error) {
|
|||
return "", "", fmt.Errorf("DNSimple API call failed: %v", err)
|
||||
}
|
||||
|
||||
authZone, err := acme.FindZoneByFqdn(domain, acme.RecursiveNameserver)
|
||||
authZone, err := acme.FindZoneByFqdn(domain, acme.RecursiveNameservers)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
|
|
@ -75,7 +75,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error {
|
|||
ttl = 300 // 300 is gandi minimum value for ttl
|
||||
}
|
||||
// find authZone and Gandi zone_id for fqdn
|
||||
authZone, err := findZoneByFqdn(fqdn, acme.RecursiveNameserver)
|
||||
authZone, err := findZoneByFqdn(fqdn, acme.RecursiveNameservers)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Gandi DNS: findZoneByFqdn failure: %v", err)
|
||||
}
|
||||
|
|
|
@ -71,7 +71,7 @@ func TestDNSProvider(t *testing.T) {
|
|||
}))
|
||||
defer fakeServer.Close()
|
||||
// define function to override findZoneByFqdn with
|
||||
fakeFindZoneByFqdn := func(fqdn, nameserver string) (string, error) {
|
||||
fakeFindZoneByFqdn := func(fqdn string, nameserver []string) (string, error) {
|
||||
return "example.com.", nil
|
||||
}
|
||||
// override gandi endpoint and findZoneByFqdn function
|
||||
|
|
|
@ -82,7 +82,7 @@ func (r *DNSProvider) CleanUp(domain, token, keyAuth string) error {
|
|||
|
||||
func (r *DNSProvider) changeRecord(action, fqdn, value string, ttl int) error {
|
||||
// Find the zone for the given fqdn
|
||||
zone, err := acme.FindZoneByFqdn(fqdn, r.nameserver)
|
||||
zone, err := acme.FindZoneByFqdn(fqdn, []string{r.nameserver})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -124,7 +124,7 @@ func (r *DNSProvider) changeRecord(action, fqdn, value string, ttl int) error {
|
|||
}
|
||||
|
||||
func (r *DNSProvider) getHostedZoneID(fqdn string) (string, error) {
|
||||
authZone, err := acme.FindZoneByFqdn(fqdn, acme.RecursiveNameserver)
|
||||
authZone, err := acme.FindZoneByFqdn(fqdn, acme.RecursiveNameservers)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue