From c97b5a52a13532fe637e26b7bd00ba313fc1c51c Mon Sep 17 00:00:00 2001 From: Jan Broer Date: Wed, 3 Feb 2016 05:03:03 +0100 Subject: [PATCH] Refactor DNS check * Gets a list of all authoritative nameservers by looking up the NS RRs for the root domain (zone apex) * Verifies that the expected TXT record exists on all nameservers before sending off the challenge to ACME server --- acme/dns_challenge.go | 151 +++++++++++++++++++++++++--------- acme/dns_challenge_route53.go | 2 +- acme/dns_challenge_test.go | 141 ++++++++++++++++++++++++++++++- 3 files changed, 252 insertions(+), 42 deletions(-) diff --git a/acme/dns_challenge.go b/acme/dns_challenge.go index f34fcccc..b0753499 100644 --- a/acme/dns_challenge.go +++ b/acme/dns_challenge.go @@ -6,17 +6,18 @@ import ( "errors" "fmt" "log" + "net" "strings" "time" "github.com/miekg/dns" ) -type preCheckDNSFunc func(domain, fqdn string) bool +type preCheckDNSFunc func(domain, fqdn, value string) error -var preCheckDNS preCheckDNSFunc = checkDNS +var preCheckDNS preCheckDNSFunc = checkDnsPropagation -var preCheckDNSFallbackCount = 5 +var recursionMaxDepth = 10 // DNS01Record returns a DNS record which will fulfill the `dns-01` challenge func DNS01Record(domain, keyAuth string) (fqdn string, value string, ttl int) { @@ -60,50 +61,121 @@ func (s *dnsChallenge) Solve(chlng challenge, domain string) error { } }() - fqdn, _, _ := DNS01Record(domain, keyAuth) + fqdn, value, _ := DNS01Record(domain, keyAuth) - preCheckDNS(domain, fqdn) + logf("[INFO][%s] Checking DNS record propagation...", domain) + + if err = preCheckDNS(domain, fqdn, value); err != nil { + return err + } return s.validate(s.jws, domain, chlng.URI, challenge{Resource: "challenge", Type: chlng.Type, Token: chlng.Token, KeyAuthorization: keyAuth}) } -func checkDNS(domain, fqdn string) bool { - // check if the expected DNS entry was created. If not wait for some time and try again. - m := new(dns.Msg) - m.SetQuestion(domain+".", dns.TypeSOA) - c := new(dns.Client) - in, _, err := c.Exchange(m, "google-public-dns-a.google.com:53") +// checkDnsPropagation checks if the expected TXT record has been propagated to +// all authoritative nameservers. If not it waits and retries for some time. +func checkDnsPropagation(domain, fqdn, value string) error { + authoritativeNss, err := lookupNameservers(toFqdn(domain)) if err != nil { - return false + return err } - var authorativeNS string - for _, answ := range in.Answer { - soa := answ.(*dns.SOA) - authorativeNS = soa.Ns + if err = waitFor(30, 2, func() (bool, error) { + return checkAuthoritativeNss(fqdn, value, authoritativeNss) + }); err != nil { + return err } - fallbackCnt := 0 - for fallbackCnt < preCheckDNSFallbackCount { - m.SetQuestion(fqdn, dns.TypeTXT) - in, _, err = c.Exchange(m, authorativeNS+":53") + return nil +} + +// checkAuthoritativeNss queries each of the given nameservers for the expected TXT record. +func checkAuthoritativeNss(fqdn, value string, nameservers []string) (bool, error) { + for _, ns := range nameservers { + r, err := dnsQuery(fqdn, dns.TypeTXT, ns, false) if err != nil { - return false + return false, err } - if len(in.Answer) > 0 { - return true + if r.Rcode != dns.RcodeSuccess { + return false, fmt.Errorf("%s returned RCode %s", ns, dns.RcodeToString[r.Rcode]) } - fallbackCnt++ - if fallbackCnt >= preCheckDNSFallbackCount { - return false + var found bool + for _, rr := range r.Answer { + if txt, ok := rr.(*dns.TXT); ok { + if strings.Join(txt.Txt, "") == value { + found = true + break + } + } } - time.Sleep(time.Second * time.Duration(fallbackCnt)) + if !found { + return false, fmt.Errorf("%s did not return the expected TXT record", ns) + } } - return false + return true, nil +} + +// dnsQuery sends a DNS query to the given nameserver. +func dnsQuery(fqdn string, rtype uint16, nameserver string, recursive bool) (in *dns.Msg, err error) { + m := new(dns.Msg) + m.SetQuestion(fqdn, rtype) + m.SetEdns0(4096, false) + if !recursive { + m.RecursionDesired = false + } + + in, err = dns.Exchange(m, net.JoinHostPort(nameserver, "53")) + if err == dns.ErrTruncated { + tcp := &dns.Client{Net: "tcp"} + in, _, err = tcp.Exchange(m, nameserver) + } + + return +} + +// lookupNameservers returns the authoritative nameservers for the given domain name. +func lookupNameservers(fqdn string) ([]string, error) { + var err error + var r *dns.Msg + var authoritativeNss []string + resolver := "google-public-dns-a.google.com" + + r, err = dnsQuery(fqdn, dns.TypeSOA, resolver, true) + if err != nil { + return nil, err + } + + // If there is a SOA RR in the Answer section then fqdn is the root domain. + for _, rr := range r.Answer { + if soa, ok := rr.(*dns.SOA); ok { + r, err = dnsQuery(soa.Hdr.Name, dns.TypeNS, resolver, true) + if err != nil { + return nil, err + } + + for _, rr := range r.Answer { + if ns, ok := rr.(*dns.NS); ok { + authoritativeNss = append(authoritativeNss, strings.ToLower(ns.Ns)) + } + } + + return authoritativeNss, nil + } + } + + // Strip of the left most label to get the parent domain. + offset, _ := dns.NextLabel(fqdn, 0) + next := fqdn[offset:] + // Only the TLD label left. This should not happen if the domain DNS is healthy. + if dns.CountLabel(next) < 2 { + return nil, fmt.Errorf("Could not determine root domain") + } + + return lookupNameservers(fqdn[offset:]) } // toFqdn converts the name into a fqdn appending a trailing dot. @@ -124,22 +196,25 @@ func unFqdn(name string) string { return name } -// waitFor polls the given function 'f', once per second, up to 'timeout' seconds. -func waitFor(timeout int, f func() (bool, error)) error { - start := time.Now().Second() +// waitFor polls the given function 'f', once every 'interval' seconds, up to 'timeout' seconds. +func waitFor(timeout, interval int, f func() (bool, error)) error { + var lastErr string + timeup := time.After(time.Duration(timeout) * time.Second) for { - time.Sleep(1 * time.Second) - - if delta := time.Now().Second() - start; delta >= timeout { - return fmt.Errorf("Time limit exceeded (%d seconds)", delta) + select { + case <-timeup: + return fmt.Errorf("Time limit exceeded. Last error: %s", lastErr) + default: } stop, err := f() - if err != nil { - return err - } if stop { return nil } + if err != nil { + lastErr = err.Error() + } + + time.Sleep(time.Duration(interval) * time.Second) } } diff --git a/acme/dns_challenge_route53.go b/acme/dns_challenge_route53.go index 28ed0259..491117e2 100644 --- a/acme/dns_challenge_route53.go +++ b/acme/dns_challenge_route53.go @@ -68,7 +68,7 @@ func (r *DNSProviderRoute53) changeRecord(action, fqdn, value string, ttl int) e return err } - return waitFor(90, func() (bool, error) { + return waitFor(90, 5, func() (bool, error) { status, err := r.client.GetChange(resp.ChangeInfo.ID) if err != nil { return false, err diff --git a/acme/dns_challenge_test.go b/acme/dns_challenge_test.go index 0af40f71..046f792a 100644 --- a/acme/dns_challenge_test.go +++ b/acme/dns_challenge_test.go @@ -6,13 +6,75 @@ import ( "net/http" "net/http/httptest" "os" + "reflect" + "sort" + "strings" "testing" "time" ) +var lookupNameserversTestsOK = []struct { + fqdn string + nss []string +}{ + {"books.google.com.ng.", + []string{"ns1.google.com.", "ns2.google.com.", "ns3.google.com.", "ns4.google.com."}, + }, + {"www.google.com.", + []string{"ns1.google.com.", "ns2.google.com.", "ns3.google.com.", "ns4.google.com."}, + }, + {"physics.georgetown.edu.", + []string{"ns1.georgetown.edu.", "ns2.georgetown.edu.", "ns3.georgetown.edu."}, + }, +} + +var lookupNameserversTestsErr = []struct { + fqdn string + error string +}{ + // invalid tld + {"_null.n0n0.", + "Could not determine root domain", + }, + // invalid domain + {"_null.com.", + "Could not determine root domain", + }, +} + +var checkAuthoritativeNssTests = []struct { + fqdn, value string + ns []string + ok bool +}{ + // TXT RR w/ expected value + {"8.8.8.8.asn.routeviews.org.", "151698.8.8.024", []string{"asnums.routeviews.org."}, + true, + }, + // No TXT RR + {"ns1.google.com.", "", []string{"ns2.google.com."}, + false, + }, +} + +var checkAuthoritativeNssTestsErr = []struct { + fqdn, value string + ns []string + error string +}{ + // TXT RR /w unexpected value + {"8.8.8.8.asn.routeviews.org.", "fe01=", []string{"asnums.routeviews.org."}, + "did not return the expected TXT record", + }, + // No TXT RR + {"ns1.google.com.", "fe01=", []string{"ns2.google.com."}, + "did not return the expected TXT record", + }, +} + func TestDNSValidServerResponse(t *testing.T) { - preCheckDNS = func(domain, fqdn string) bool { - return true + preCheckDNS = func(domain, fqdn, value string) error { + return nil } privKey, _ := generatePrivateKey(rsakey, 512) @@ -39,7 +101,80 @@ func TestDNSValidServerResponse(t *testing.T) { } func TestPreCheckDNS(t *testing.T) { - if !preCheckDNS("api.letsencrypt.org", "acme-staging.api.letsencrypt.org") { + err := preCheckDNS("api.letsencrypt.org", "acme-staging.api.letsencrypt.org", "fe01=") + if err != nil { t.Errorf("preCheckDNS failed for acme-staging.api.letsencrypt.org") } } + +func TestLookupNameserversOK(t *testing.T) { + for _, tt := range lookupNameserversTestsOK { + nss, err := lookupNameservers(tt.fqdn) + if err != nil { + t.Fatalf("#%s: got %q; want nil", tt.fqdn, err) + } + + sort.Strings(nss) + sort.Strings(tt.nss) + + if !reflect.DeepEqual(nss, tt.nss) { + t.Errorf("#%s: got %v; want %v", tt.fqdn, nss, tt.nss) + } + } +} + +func TestLookupNameserversErr(t *testing.T) { + for _, tt := range lookupNameserversTestsErr { + _, err := lookupNameservers(tt.fqdn) + if err == nil { + t.Fatalf("#%s: expected %q (error); got ", tt.fqdn, tt.error) + } + + if !strings.Contains(err.Error(), tt.error) { + t.Errorf("#%s: expected %q (error); got %q", tt.fqdn, tt.error, err) + continue + } + } +} + +func TestCheckAuthoritativeNss(t *testing.T) { + for _, tt := range checkAuthoritativeNssTests { + ok, _ := checkAuthoritativeNss(tt.fqdn, tt.value, tt.ns) + if ok != tt.ok { + t.Errorf("#%s: got %t; want %t", tt.fqdn, tt.ok) + } + } +} + +func TestCheckAuthoritativeNssErr(t *testing.T) { + for _, tt := range checkAuthoritativeNssTestsErr { + _, err := checkAuthoritativeNss(tt.fqdn, tt.value, tt.ns) + if err == nil { + t.Fatalf("#%s: expected %q (error); got ", tt.fqdn, tt.error) + } + if !strings.Contains(err.Error(), tt.error) { + t.Errorf("#%s: expected %q (error); got %q", tt.fqdn, tt.error, err) + continue + } + } +} + +func TestWaitForTimeout(t *testing.T) { + c := make(chan error) + go func() { + err := waitFor(3, 1, func() (bool, error) { + return false, nil + }) + c <- err + }() + + timeout := time.After(4 * time.Second) + select { + case <-timeout: + t.Fatal("timeout exceeded") + case err := <-c: + if err == nil { + t.Errorf("expected timeout error; got ", err) + } + } +}