From c97b5a52a13532fe637e26b7bd00ba313fc1c51c Mon Sep 17 00:00:00 2001
From: Jan Broer <janeczku@yahoo.de>
Date: Wed, 3 Feb 2016 05:03:03 +0100
Subject: [PATCH] Refactor DNS check

* Gets a list of all authoritative nameservers by looking up the NS RRs for the root domain (zone apex)
* Verifies that the expected TXT record exists on all nameservers before sending off the challenge to ACME server
---
 acme/dns_challenge.go         | 151 +++++++++++++++++++++++++---------
 acme/dns_challenge_route53.go |   2 +-
 acme/dns_challenge_test.go    | 141 ++++++++++++++++++++++++++++++-
 3 files changed, 252 insertions(+), 42 deletions(-)

diff --git a/acme/dns_challenge.go b/acme/dns_challenge.go
index f34fcccc..b0753499 100644
--- a/acme/dns_challenge.go
+++ b/acme/dns_challenge.go
@@ -6,17 +6,18 @@ import (
 	"errors"
 	"fmt"
 	"log"
+	"net"
 	"strings"
 	"time"
 
 	"github.com/miekg/dns"
 )
 
-type preCheckDNSFunc func(domain, fqdn string) bool
+type preCheckDNSFunc func(domain, fqdn, value string) error
 
-var preCheckDNS preCheckDNSFunc = checkDNS
+var preCheckDNS preCheckDNSFunc = checkDnsPropagation
 
-var preCheckDNSFallbackCount = 5
+var recursionMaxDepth = 10
 
 // DNS01Record returns a DNS record which will fulfill the `dns-01` challenge
 func DNS01Record(domain, keyAuth string) (fqdn string, value string, ttl int) {
@@ -60,50 +61,121 @@ func (s *dnsChallenge) Solve(chlng challenge, domain string) error {
 		}
 	}()
 
-	fqdn, _, _ := DNS01Record(domain, keyAuth)
+	fqdn, value, _ := DNS01Record(domain, keyAuth)
 
-	preCheckDNS(domain, fqdn)
+	logf("[INFO][%s] Checking DNS record propagation...", domain)
+
+	if err = preCheckDNS(domain, fqdn, value); err != nil {
+		return err
+	}
 
 	return s.validate(s.jws, domain, chlng.URI, challenge{Resource: "challenge", Type: chlng.Type, Token: chlng.Token, KeyAuthorization: keyAuth})
 }
 
-func checkDNS(domain, fqdn string) bool {
-	// check if the expected DNS entry was created. If not wait for some time and try again.
-	m := new(dns.Msg)
-	m.SetQuestion(domain+".", dns.TypeSOA)
-	c := new(dns.Client)
-	in, _, err := c.Exchange(m, "google-public-dns-a.google.com:53")
+// checkDnsPropagation checks if the expected TXT record has been propagated to
+// all authoritative nameservers. If not it waits and retries for some time.
+func checkDnsPropagation(domain, fqdn, value string) error {
+	authoritativeNss, err := lookupNameservers(toFqdn(domain))
 	if err != nil {
-		return false
+		return err
 	}
 
-	var authorativeNS string
-	for _, answ := range in.Answer {
-		soa := answ.(*dns.SOA)
-		authorativeNS = soa.Ns
+	if err = waitFor(30, 2, func() (bool, error) {
+		return checkAuthoritativeNss(fqdn, value, authoritativeNss)
+	}); err != nil {
+		return err
 	}
 
-	fallbackCnt := 0
-	for fallbackCnt < preCheckDNSFallbackCount {
-		m.SetQuestion(fqdn, dns.TypeTXT)
-		in, _, err = c.Exchange(m, authorativeNS+":53")
+	return nil
+}
+
+// checkAuthoritativeNss queries each of the given nameservers for the expected TXT record.
+func checkAuthoritativeNss(fqdn, value string, nameservers []string) (bool, error) {
+	for _, ns := range nameservers {
+		r, err := dnsQuery(fqdn, dns.TypeTXT, ns, false)
 		if err != nil {
-			return false
+			return false, err
 		}
 
-		if len(in.Answer) > 0 {
-			return true
+		if r.Rcode != dns.RcodeSuccess {
+			return false, fmt.Errorf("%s returned RCode %s", ns, dns.RcodeToString[r.Rcode])
 		}
 
-		fallbackCnt++
-		if fallbackCnt >= preCheckDNSFallbackCount {
-			return false
+		var found bool
+		for _, rr := range r.Answer {
+			if txt, ok := rr.(*dns.TXT); ok {
+				if strings.Join(txt.Txt, "") == value {
+					found = true
+					break
+				}
+			}
 		}
 
-		time.Sleep(time.Second * time.Duration(fallbackCnt))
+		if !found {
+			return false, fmt.Errorf("%s did not return the expected TXT record", ns)
+		}
 	}
 
-	return false
+	return true, nil
+}
+
+// dnsQuery sends a DNS query to the given nameserver.
+func dnsQuery(fqdn string, rtype uint16, nameserver string, recursive bool) (in *dns.Msg, err error) {
+	m := new(dns.Msg)
+	m.SetQuestion(fqdn, rtype)
+	m.SetEdns0(4096, false)
+	if !recursive {
+		m.RecursionDesired = false
+	}
+
+	in, err = dns.Exchange(m, net.JoinHostPort(nameserver, "53"))
+	if err == dns.ErrTruncated {
+		tcp := &dns.Client{Net: "tcp"}
+		in, _, err = tcp.Exchange(m, nameserver)
+	}
+
+	return
+}
+
+// lookupNameservers returns the authoritative nameservers for the given domain name.
+func lookupNameservers(fqdn string) ([]string, error) {
+	var err error
+	var r *dns.Msg
+	var authoritativeNss []string
+	resolver := "google-public-dns-a.google.com"
+
+	r, err = dnsQuery(fqdn, dns.TypeSOA, resolver, true)
+	if err != nil {
+		return nil, err
+	}
+
+	// If there is a SOA RR in the Answer section then fqdn is the root domain.
+	for _, rr := range r.Answer {
+		if soa, ok := rr.(*dns.SOA); ok {
+			r, err = dnsQuery(soa.Hdr.Name, dns.TypeNS, resolver, true)
+			if err != nil {
+				return nil, err
+			}
+
+			for _, rr := range r.Answer {
+				if ns, ok := rr.(*dns.NS); ok {
+					authoritativeNss = append(authoritativeNss, strings.ToLower(ns.Ns))
+				}
+			}
+
+			return authoritativeNss, nil
+		}
+	}
+
+	// Strip of the left most label to get the parent domain.
+	offset, _ := dns.NextLabel(fqdn, 0)
+	next := fqdn[offset:]
+	// Only the TLD label left. This should not happen if the domain DNS is healthy.
+	if dns.CountLabel(next) < 2 {
+		return nil, fmt.Errorf("Could not determine root domain")
+	}
+	
+	return lookupNameservers(fqdn[offset:])
 }
 
 // toFqdn converts the name into a fqdn appending a trailing dot.
@@ -124,22 +196,25 @@ func unFqdn(name string) string {
 	return name
 }
 
-// waitFor polls the given function 'f', once per second, up to 'timeout' seconds.
-func waitFor(timeout int, f func() (bool, error)) error {
-	start := time.Now().Second()
+// waitFor polls the given function 'f', once every 'interval' seconds, up to 'timeout' seconds.
+func waitFor(timeout, interval int, f func() (bool, error)) error {
+	var lastErr string
+	timeup := time.After(time.Duration(timeout) * time.Second)
 	for {
-		time.Sleep(1 * time.Second)
-
-		if delta := time.Now().Second() - start; delta >= timeout {
-			return fmt.Errorf("Time limit exceeded (%d seconds)", delta)
+		select {
+		case <-timeup:
+			return fmt.Errorf("Time limit exceeded. Last error: %s", lastErr)
+		default:
 		}
 
 		stop, err := f()
-		if err != nil {
-			return err
-		}
 		if stop {
 			return nil
 		}
+		if err != nil {
+			lastErr = err.Error()
+		}
+
+		time.Sleep(time.Duration(interval) * time.Second)
 	}
 }
diff --git a/acme/dns_challenge_route53.go b/acme/dns_challenge_route53.go
index 28ed0259..491117e2 100644
--- a/acme/dns_challenge_route53.go
+++ b/acme/dns_challenge_route53.go
@@ -68,7 +68,7 @@ func (r *DNSProviderRoute53) changeRecord(action, fqdn, value string, ttl int) e
 		return err
 	}
 
-	return waitFor(90, func() (bool, error) {
+	return waitFor(90, 5, func() (bool, error) {
 		status, err := r.client.GetChange(resp.ChangeInfo.ID)
 		if err != nil {
 			return false, err
diff --git a/acme/dns_challenge_test.go b/acme/dns_challenge_test.go
index 0af40f71..046f792a 100644
--- a/acme/dns_challenge_test.go
+++ b/acme/dns_challenge_test.go
@@ -6,13 +6,75 @@ import (
 	"net/http"
 	"net/http/httptest"
 	"os"
+	"reflect"
+	"sort"
+	"strings"
 	"testing"
 	"time"
 )
 
+var lookupNameserversTestsOK = []struct {
+	fqdn string
+	nss  []string
+}{
+	{"books.google.com.ng.",
+		[]string{"ns1.google.com.", "ns2.google.com.", "ns3.google.com.", "ns4.google.com."},
+	},
+	{"www.google.com.",
+		[]string{"ns1.google.com.", "ns2.google.com.", "ns3.google.com.", "ns4.google.com."},
+	},
+	{"physics.georgetown.edu.",
+		[]string{"ns1.georgetown.edu.", "ns2.georgetown.edu.", "ns3.georgetown.edu."},
+	},
+}
+
+var lookupNameserversTestsErr = []struct {
+	fqdn  string
+	error string
+}{
+	// invalid tld
+	{"_null.n0n0.",
+		"Could not determine root domain",
+	},
+	// invalid domain
+	{"_null.com.",
+		"Could not determine root domain",
+	},
+}
+
+var checkAuthoritativeNssTests = []struct {
+	fqdn, value string
+	ns          []string
+	ok          bool
+}{
+	// TXT RR w/ expected value
+	{"8.8.8.8.asn.routeviews.org.", "151698.8.8.024", []string{"asnums.routeviews.org."},
+		true,
+	},
+	// No TXT RR
+	{"ns1.google.com.", "", []string{"ns2.google.com."},
+		false,
+	},
+}
+
+var checkAuthoritativeNssTestsErr = []struct {
+	fqdn, value string
+	ns          []string
+	error       string
+}{
+	// TXT RR /w unexpected value
+	{"8.8.8.8.asn.routeviews.org.", "fe01=", []string{"asnums.routeviews.org."},
+		"did not return the expected TXT record",
+	},
+	// No TXT RR
+	{"ns1.google.com.", "fe01=", []string{"ns2.google.com."},
+		"did not return the expected TXT record",
+	},
+}
+
 func TestDNSValidServerResponse(t *testing.T) {
-	preCheckDNS = func(domain, fqdn string) bool {
-		return true
+	preCheckDNS = func(domain, fqdn, value string) error {
+		return nil
 	}
 	privKey, _ := generatePrivateKey(rsakey, 512)
 
@@ -39,7 +101,80 @@ func TestDNSValidServerResponse(t *testing.T) {
 }
 
 func TestPreCheckDNS(t *testing.T) {
-	if !preCheckDNS("api.letsencrypt.org", "acme-staging.api.letsencrypt.org") {
+	err := preCheckDNS("api.letsencrypt.org", "acme-staging.api.letsencrypt.org", "fe01=")
+	if err != nil {
 		t.Errorf("preCheckDNS failed for acme-staging.api.letsencrypt.org")
 	}
 }
+
+func TestLookupNameserversOK(t *testing.T) {
+	for _, tt := range lookupNameserversTestsOK {
+		nss, err := lookupNameservers(tt.fqdn)
+		if err != nil {
+			t.Fatalf("#%s: got %q; want nil", tt.fqdn, err)
+		}
+
+		sort.Strings(nss)
+		sort.Strings(tt.nss)
+
+		if !reflect.DeepEqual(nss, tt.nss) {
+			t.Errorf("#%s: got %v; want %v", tt.fqdn, nss, tt.nss)
+		}
+	}
+}
+
+func TestLookupNameserversErr(t *testing.T) {
+	for _, tt := range lookupNameserversTestsErr {
+		_, err := lookupNameservers(tt.fqdn)
+		if err == nil {
+			t.Fatalf("#%s: expected %q (error); got <nil>", tt.fqdn, tt.error)
+		}
+
+		if !strings.Contains(err.Error(), tt.error) {
+			t.Errorf("#%s: expected %q (error); got %q", tt.fqdn, tt.error, err)
+			continue
+		}
+	}
+}
+
+func TestCheckAuthoritativeNss(t *testing.T) {
+	for _, tt := range checkAuthoritativeNssTests {
+		ok, _ := checkAuthoritativeNss(tt.fqdn, tt.value, tt.ns)
+		if ok != tt.ok {
+			t.Errorf("#%s: got %t; want %t", tt.fqdn, tt.ok)
+		}
+	}
+}
+
+func TestCheckAuthoritativeNssErr(t *testing.T) {
+	for _, tt := range checkAuthoritativeNssTestsErr {
+		_, err := checkAuthoritativeNss(tt.fqdn, tt.value, tt.ns)
+		if err == nil {
+			t.Fatalf("#%s: expected %q (error); got <nil>", tt.fqdn, tt.error)
+		}
+		if !strings.Contains(err.Error(), tt.error) {
+			t.Errorf("#%s: expected %q (error); got %q", tt.fqdn, tt.error, err)
+			continue
+		}
+	}
+}
+
+func TestWaitForTimeout(t *testing.T) {
+	c := make(chan error)
+	go func() {
+		err := waitFor(3, 1, func() (bool, error) {
+			return false, nil
+		})
+		c <- err
+	}()
+
+	timeout := time.After(4 * time.Second)
+	select {
+	case <-timeout:
+		t.Fatal("timeout exceeded")
+	case err := <-c:
+		if err == nil {
+			t.Errorf("expected timeout error; got <nil>", err)
+		}
+	}
+}