Retry logic for dnsQuery

Added a slice of NS to be used when retrying queries. Also used with FindZoneByFqdn()
Adjusted 2 error messages given to better differentiate the returned error string
This commit is contained in:
LukeHandle 2016-04-11 23:59:59 +01:00
parent 74c6bbee86
commit dbad97ebc6
9 changed files with 40 additions and 28 deletions

View file

@ -21,7 +21,10 @@ var (
fqdnToZone = map[string]string{} fqdnToZone = map[string]string{}
) )
var RecursiveNameserver = "google-public-dns-a.google.com:53" var RecursiveNameservers = []string{
"google-public-dns-a.google.com:53",
"google-public-dns-b.google.com:53",
}
// DNS01Record returns a DNS record which will fulfill the `dns-01` challenge // DNS01Record returns a DNS record which will fulfill the `dns-01` challenge
func DNS01Record(domain, keyAuth string) (fqdn string, value string, ttl int) { func DNS01Record(domain, keyAuth string) (fqdn string, value string, ttl int) {
@ -56,12 +59,12 @@ func (s *dnsChallenge) Solve(chlng challenge, domain string) error {
err = s.provider.Present(domain, chlng.Token, keyAuth) err = s.provider.Present(domain, chlng.Token, keyAuth)
if err != nil { if err != nil {
return fmt.Errorf("Error presenting token %s", err) return fmt.Errorf("Error presenting token: %s", err)
} }
defer func() { defer func() {
err := s.provider.CleanUp(domain, chlng.Token, keyAuth) err := s.provider.CleanUp(domain, chlng.Token, keyAuth)
if err != nil { if err != nil {
log.Printf("Error cleaning up %s %v ", domain, err) log.Printf("Error cleaning up %s: %v ", domain, err)
} }
}() }()
@ -90,7 +93,7 @@ func (s *dnsChallenge) Solve(chlng challenge, domain string) error {
// checkDNSPropagation checks if the expected TXT record has been propagated to all authoritative nameservers. // checkDNSPropagation checks if the expected TXT record has been propagated to all authoritative nameservers.
func checkDNSPropagation(fqdn, value string) (bool, error) { func checkDNSPropagation(fqdn, value string) (bool, error) {
// Initial attempt to resolve at the recursive NS // Initial attempt to resolve at the recursive NS
r, err := dnsQuery(fqdn, dns.TypeTXT, RecursiveNameserver, true) r, err := dnsQuery(fqdn, dns.TypeTXT, RecursiveNameservers, true)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -117,7 +120,7 @@ func checkDNSPropagation(fqdn, value string) (bool, error) {
// checkAuthoritativeNss queries each of the given nameservers for the expected TXT record. // checkAuthoritativeNss queries each of the given nameservers for the expected TXT record.
func checkAuthoritativeNss(fqdn, value string, nameservers []string) (bool, error) { func checkAuthoritativeNss(fqdn, value string, nameservers []string) (bool, error) {
for _, ns := range nameservers { for _, ns := range nameservers {
r, err := dnsQuery(fqdn, dns.TypeTXT, net.JoinHostPort(ns, "53"), false) r, err := dnsQuery(fqdn, dns.TypeTXT, []string{net.JoinHostPort(ns, "53")}, false)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -144,9 +147,9 @@ func checkAuthoritativeNss(fqdn, value string, nameservers []string) (bool, erro
return true, nil return true, nil
} }
// dnsQuery sends a DNS query to the given nameserver. // dnsQuery will query a nameserver, iterating through the supplied servers as it retries
// The nameserver should include a port, to facilitate testing where we talk to a mock dns server. // The nameserver should include a port, to facilitate testing where we talk to a mock dns server.
func dnsQuery(fqdn string, rtype uint16, nameserver string, recursive bool) (in *dns.Msg, err error) { func dnsQuery(fqdn string, rtype uint16, nameservers []string, recursive bool) (in *dns.Msg, err error) {
m := new(dns.Msg) m := new(dns.Msg)
m.SetQuestion(fqdn, rtype) m.SetQuestion(fqdn, rtype)
m.SetEdns0(4096, false) m.SetEdns0(4096, false)
@ -155,12 +158,21 @@ func dnsQuery(fqdn string, rtype uint16, nameserver string, recursive bool) (in
m.RecursionDesired = false m.RecursionDesired = false
} }
in, err = dns.Exchange(m, nameserver) // Will rety the request based on the number of servers (n+1)
if err == dns.ErrTruncated { for i := 1; i <= len(nameservers)+1; i++ {
tcp := &dns.Client{Net: "tcp"} ns := nameservers[i%len(nameservers)]
in, _, err = tcp.Exchange(m, nameserver) in, err = dns.Exchange(m, ns)
}
if err == dns.ErrTruncated {
tcp := &dns.Client{Net: "tcp"}
// If the TCP request suceeds, the err will reset to nil
in, _, err = tcp.Exchange(m, ns)
}
if err == nil {
break
}
}
return return
} }
@ -168,12 +180,12 @@ func dnsQuery(fqdn string, rtype uint16, nameserver string, recursive bool) (in
func lookupNameservers(fqdn string) ([]string, error) { func lookupNameservers(fqdn string) ([]string, error) {
var authoritativeNss []string var authoritativeNss []string
zone, err := FindZoneByFqdn(fqdn, RecursiveNameserver) zone, err := FindZoneByFqdn(fqdn, RecursiveNameservers)
if err != nil { if err != nil {
return nil, err return nil, err
} }
r, err := dnsQuery(zone, dns.TypeNS, RecursiveNameserver, true) r, err := dnsQuery(zone, dns.TypeNS, RecursiveNameservers, true)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -191,7 +203,7 @@ func lookupNameservers(fqdn string) ([]string, error) {
} }
// FindZoneByFqdn determines the zone of the given fqdn // FindZoneByFqdn determines the zone of the given fqdn
func FindZoneByFqdn(fqdn, nameserver string) (string, error) { func FindZoneByFqdn(fqdn string, nameservers []string) (string, error) {
// Do we have it cached? // Do we have it cached?
if zone, ok := fqdnToZone[fqdn]; ok { if zone, ok := fqdnToZone[fqdn]; ok {
return zone, nil return zone, nil
@ -203,13 +215,13 @@ func FindZoneByFqdn(fqdn, nameserver string) (string, error) {
// Name servers authoritative for a zone MUST include the SOA record of // Name servers authoritative for a zone MUST include the SOA record of
// the zone in the authority section of the response when reporting an // the zone in the authority section of the response when reporting an
// NXDOMAIN or indicating that no data (NODATA) of the requested type exists // NXDOMAIN or indicating that no data (NODATA) of the requested type exists
in, err := dnsQuery(fqdn, dns.TypeSOA, nameserver, true) in, err := dnsQuery(fqdn, dns.TypeSOA, nameservers, true)
if err != nil { if err != nil {
return "", err return "", err
} }
if in.Rcode != dns.RcodeNameError { if in.Rcode != dns.RcodeNameError {
if in.Rcode != dns.RcodeSuccess { if in.Rcode != dns.RcodeSuccess {
return "", fmt.Errorf("NS %s returned %s for %s", nameserver, dns.RcodeToString[in.Rcode], fqdn) return "", fmt.Errorf("The NS returned %s for %s", dns.RcodeToString[in.Rcode], fqdn)
} }
// We have a success, so one of the answers has to be a SOA RR // We have a success, so one of the answers has to be a SOA RR
for _, ans := range in.Answer { for _, ans := range in.Answer {
@ -225,7 +237,7 @@ func FindZoneByFqdn(fqdn, nameserver string) (string, error) {
return checkIfTLD(fqdn, soa) return checkIfTLD(fqdn, soa)
} }
} }
return "", fmt.Errorf("NS %s did not return the expected SOA record in the authority section", nameserver) return "", fmt.Errorf("The NS did not return the expected SOA record in the authority section")
} }
func checkIfTLD(fqdn string, soa *dns.SOA) (string, error) { func checkIfTLD(fqdn string, soa *dns.SOA) (string, error) {

View file

@ -23,7 +23,7 @@ func (*DNSProviderManual) Present(domain, token, keyAuth string) error {
fqdn, value, ttl := DNS01Record(domain, keyAuth) fqdn, value, ttl := DNS01Record(domain, keyAuth)
dnsRecord := fmt.Sprintf(dnsTemplate, fqdn, ttl, value) dnsRecord := fmt.Sprintf(dnsTemplate, fqdn, ttl, value)
authZone, err := FindZoneByFqdn(fqdn, RecursiveNameserver) authZone, err := FindZoneByFqdn(fqdn, RecursiveNameservers)
if err != nil { if err != nil {
return err return err
} }
@ -42,7 +42,7 @@ func (*DNSProviderManual) CleanUp(domain, token, keyAuth string) error {
fqdn, _, ttl := DNS01Record(domain, keyAuth) fqdn, _, ttl := DNS01Record(domain, keyAuth)
dnsRecord := fmt.Sprintf(dnsTemplate, fqdn, ttl, "...") dnsRecord := fmt.Sprintf(dnsTemplate, fqdn, ttl, "...")
authZone, err := FindZoneByFqdn(fqdn, RecursiveNameserver) authZone, err := FindZoneByFqdn(fqdn, RecursiveNameservers)
if err != nil { if err != nil {
return err return err
} }

View file

@ -104,7 +104,7 @@ func (c *DNSProvider) getHostedZoneID(fqdn string) (string, error) {
Name string `json:"name"` Name string `json:"name"`
} }
authZone, err := acme.FindZoneByFqdn(fqdn, acme.RecursiveNameserver) authZone, err := acme.FindZoneByFqdn(fqdn, acme.RecursiveNameservers)
if err != nil { if err != nil {
return "", err return "", err
} }

View file

@ -63,7 +63,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error {
fqdn, value, _ := acme.DNS01Record(domain, keyAuth) fqdn, value, _ := acme.DNS01Record(domain, keyAuth)
authZone, err := acme.FindZoneByFqdn(acme.ToFqdn(domain), acme.RecursiveNameserver) authZone, err := acme.FindZoneByFqdn(acme.ToFqdn(domain), acme.RecursiveNameservers)
if err != nil { if err != nil {
return fmt.Errorf("Could not determine zone for domain: '%s'. %s", domain, err) return fmt.Errorf("Could not determine zone for domain: '%s'. %s", domain, err)
} }
@ -122,7 +122,7 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
return fmt.Errorf("unknown record ID for '%s'", fqdn) return fmt.Errorf("unknown record ID for '%s'", fqdn)
} }
authZone, err := acme.FindZoneByFqdn(acme.ToFqdn(domain), acme.RecursiveNameserver) authZone, err := acme.FindZoneByFqdn(acme.ToFqdn(domain), acme.RecursiveNameservers)
if err != nil { if err != nil {
return fmt.Errorf("Could not determine zone for domain: '%s'. %s", domain, err) return fmt.Errorf("Could not determine zone for domain: '%s'. %s", domain, err)
} }

View file

@ -79,7 +79,7 @@ func (c *DNSProvider) getHostedZone(domain string) (string, string, error) {
return "", "", fmt.Errorf("DNSimple API call failed: %v", err) return "", "", fmt.Errorf("DNSimple API call failed: %v", err)
} }
authZone, err := acme.FindZoneByFqdn(domain, acme.RecursiveNameserver) authZone, err := acme.FindZoneByFqdn(domain, acme.RecursiveNameservers)
if err != nil { if err != nil {
return "", "", err return "", "", err
} }

View file

@ -75,7 +75,7 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error {
ttl = 300 // 300 is gandi minimum value for ttl ttl = 300 // 300 is gandi minimum value for ttl
} }
// find authZone and Gandi zone_id for fqdn // find authZone and Gandi zone_id for fqdn
authZone, err := findZoneByFqdn(fqdn, acme.RecursiveNameserver) authZone, err := findZoneByFqdn(fqdn, acme.RecursiveNameservers)
if err != nil { if err != nil {
return fmt.Errorf("Gandi DNS: findZoneByFqdn failure: %v", err) return fmt.Errorf("Gandi DNS: findZoneByFqdn failure: %v", err)
} }

View file

@ -71,7 +71,7 @@ func TestDNSProvider(t *testing.T) {
})) }))
defer fakeServer.Close() defer fakeServer.Close()
// define function to override findZoneByFqdn with // define function to override findZoneByFqdn with
fakeFindZoneByFqdn := func(fqdn, nameserver string) (string, error) { fakeFindZoneByFqdn := func(fqdn string, nameserver []string) (string, error) {
return "example.com.", nil return "example.com.", nil
} }
// override gandi endpoint and findZoneByFqdn function // override gandi endpoint and findZoneByFqdn function

View file

@ -82,7 +82,7 @@ func (r *DNSProvider) CleanUp(domain, token, keyAuth string) error {
func (r *DNSProvider) changeRecord(action, fqdn, value string, ttl int) error { func (r *DNSProvider) changeRecord(action, fqdn, value string, ttl int) error {
// Find the zone for the given fqdn // Find the zone for the given fqdn
zone, err := acme.FindZoneByFqdn(fqdn, r.nameserver) zone, err := acme.FindZoneByFqdn(fqdn, []string{r.nameserver})
if err != nil { if err != nil {
return err return err
} }

View file

@ -124,7 +124,7 @@ func (r *DNSProvider) changeRecord(action, fqdn, value string, ttl int) error {
} }
func (r *DNSProvider) getHostedZoneID(fqdn string) (string, error) { func (r *DNSProvider) getHostedZoneID(fqdn string) (string, error) {
authZone, err := acme.FindZoneByFqdn(fqdn, acme.RecursiveNameserver) authZone, err := acme.FindZoneByFqdn(fqdn, acme.RecursiveNameservers)
if err != nil { if err != nil {
return "", err return "", err
} }