commit
ba64faa4e1
3 changed files with 256 additions and 42 deletions
|
@ -6,17 +6,18 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
|
"net"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
)
|
)
|
||||||
|
|
||||||
type preCheckDNSFunc func(domain, fqdn string) bool
|
type preCheckDNSFunc func(fqdn, value string) (bool, error)
|
||||||
|
|
||||||
var preCheckDNS preCheckDNSFunc = checkDNS
|
var preCheckDNS preCheckDNSFunc = checkDnsPropagation
|
||||||
|
|
||||||
var preCheckDNSFallbackCount = 5
|
var recursiveNameserver = "google-public-dns-a.google.com"
|
||||||
|
|
||||||
// DNS01Record returns a DNS record which will fulfill the `dns-01` challenge
|
// DNS01Record returns a DNS record which will fulfill the `dns-01` challenge
|
||||||
func DNS01Record(domain, keyAuth string) (fqdn string, value string, ttl int) {
|
func DNS01Record(domain, keyAuth string) (fqdn string, value string, ttl int) {
|
||||||
|
@ -60,50 +61,125 @@ func (s *dnsChallenge) Solve(chlng challenge, domain string) error {
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
fqdn, _, _ := DNS01Record(domain, keyAuth)
|
fqdn, value, _ := DNS01Record(domain, keyAuth)
|
||||||
|
|
||||||
preCheckDNS(domain, fqdn)
|
logf("[INFO][%s] Checking DNS record propagation...", domain)
|
||||||
|
|
||||||
|
err = waitFor(30, 2, func() (bool, error) {
|
||||||
|
return preCheckDNS(fqdn, value)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
return s.validate(s.jws, domain, chlng.URI, challenge{Resource: "challenge", Type: chlng.Type, Token: chlng.Token, KeyAuthorization: keyAuth})
|
return s.validate(s.jws, domain, chlng.URI, challenge{Resource: "challenge", Type: chlng.Type, Token: chlng.Token, KeyAuthorization: keyAuth})
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkDNS(domain, fqdn string) bool {
|
// checkDnsPropagation checks if the expected TXT record has been propagated to all authoritative nameservers.
|
||||||
// check if the expected DNS entry was created. If not wait for some time and try again.
|
func checkDnsPropagation(fqdn, value string) (bool, error) {
|
||||||
m := new(dns.Msg)
|
// Initial attempt to resolve at the recursive NS
|
||||||
m.SetQuestion(domain+".", dns.TypeSOA)
|
r, err := dnsQuery(fqdn, dns.TypeTXT, recursiveNameserver, true)
|
||||||
c := new(dns.Client)
|
|
||||||
in, _, err := c.Exchange(m, "google-public-dns-a.google.com:53")
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false, err
|
||||||
|
}
|
||||||
|
if r.Rcode != dns.RcodeSuccess {
|
||||||
|
return false, fmt.Errorf("Could not resolve %s -> %s", fqdn, dns.RcodeToString[r.Rcode])
|
||||||
}
|
}
|
||||||
|
|
||||||
var authorativeNS string
|
// If we see a CNAME here then use the alias
|
||||||
for _, answ := range in.Answer {
|
for _, rr := range r.Answer {
|
||||||
soa := answ.(*dns.SOA)
|
if cn, ok := rr.(*dns.CNAME); ok {
|
||||||
authorativeNS = soa.Ns
|
if cn.Hdr.Name == fqdn {
|
||||||
|
fqdn = cn.Target
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fallbackCnt := 0
|
authoritativeNss, err := lookupNameservers(fqdn)
|
||||||
for fallbackCnt < preCheckDNSFallbackCount {
|
if err != nil {
|
||||||
m.SetQuestion(fqdn, dns.TypeTXT)
|
return false, err
|
||||||
in, _, err = c.Exchange(m, authorativeNS+":53")
|
}
|
||||||
|
|
||||||
|
return checkAuthoritativeNss(fqdn, value, authoritativeNss)
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkAuthoritativeNss queries each of the given nameservers for the expected TXT record.
|
||||||
|
func checkAuthoritativeNss(fqdn, value string, nameservers []string) (bool, error) {
|
||||||
|
for _, ns := range nameservers {
|
||||||
|
r, err := dnsQuery(fqdn, dns.TypeTXT, ns, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(in.Answer) > 0 {
|
if r.Rcode != dns.RcodeSuccess {
|
||||||
return true
|
return false, fmt.Errorf("NS %s returned %s for %s", ns, dns.RcodeToString[r.Rcode], fqdn)
|
||||||
}
|
}
|
||||||
|
|
||||||
fallbackCnt++
|
var found bool
|
||||||
if fallbackCnt >= preCheckDNSFallbackCount {
|
for _, rr := range r.Answer {
|
||||||
return false
|
if txt, ok := rr.(*dns.TXT); ok {
|
||||||
|
if strings.Join(txt.Txt, "") == value {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
time.Sleep(time.Second * time.Duration(fallbackCnt))
|
if !found {
|
||||||
|
return false, fmt.Errorf("NS %s did not return the expected TXT record", ns)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return false
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// dnsQuery sends a DNS query to the given nameserver.
|
||||||
|
func dnsQuery(fqdn string, rtype uint16, nameserver string, recursive bool) (in *dns.Msg, err error) {
|
||||||
|
m := new(dns.Msg)
|
||||||
|
m.SetQuestion(fqdn, rtype)
|
||||||
|
m.SetEdns0(4096, false)
|
||||||
|
|
||||||
|
if !recursive {
|
||||||
|
m.RecursionDesired = false
|
||||||
|
}
|
||||||
|
|
||||||
|
in, err = dns.Exchange(m, net.JoinHostPort(nameserver, "53"))
|
||||||
|
if err == dns.ErrTruncated {
|
||||||
|
tcp := &dns.Client{Net: "tcp"}
|
||||||
|
in, _, err = tcp.Exchange(m, nameserver)
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// lookupNameservers returns the authoritative nameservers for the given fqdn.
|
||||||
|
func lookupNameservers(fqdn string) ([]string, error) {
|
||||||
|
var authoritativeNss []string
|
||||||
|
|
||||||
|
r, err := dnsQuery(fqdn, dns.TypeNS, recursiveNameserver, true)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, rr := range r.Answer {
|
||||||
|
if ns, ok := rr.(*dns.NS); ok {
|
||||||
|
authoritativeNss = append(authoritativeNss, strings.ToLower(ns.Ns))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(authoritativeNss) > 0 {
|
||||||
|
return authoritativeNss, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Strip of the left most label to get the parent domain.
|
||||||
|
offset, _ := dns.NextLabel(fqdn, 0)
|
||||||
|
next := fqdn[offset:]
|
||||||
|
if dns.CountLabel(next) < 2 {
|
||||||
|
return nil, fmt.Errorf("Could not determine authoritative nameservers")
|
||||||
|
}
|
||||||
|
|
||||||
|
return lookupNameservers(next)
|
||||||
}
|
}
|
||||||
|
|
||||||
// toFqdn converts the name into a fqdn appending a trailing dot.
|
// toFqdn converts the name into a fqdn appending a trailing dot.
|
||||||
|
@ -124,22 +200,25 @@ func unFqdn(name string) string {
|
||||||
return name
|
return name
|
||||||
}
|
}
|
||||||
|
|
||||||
// waitFor polls the given function 'f', once per second, up to 'timeout' seconds.
|
// waitFor polls the given function 'f', once every 'interval' seconds, up to 'timeout' seconds.
|
||||||
func waitFor(timeout int, f func() (bool, error)) error {
|
func waitFor(timeout, interval int, f func() (bool, error)) error {
|
||||||
start := time.Now().Second()
|
var lastErr string
|
||||||
|
timeup := time.After(time.Duration(timeout) * time.Second)
|
||||||
for {
|
for {
|
||||||
time.Sleep(1 * time.Second)
|
select {
|
||||||
|
case <-timeup:
|
||||||
if delta := time.Now().Second() - start; delta >= timeout {
|
return fmt.Errorf("Time limit exceeded. Last error: %s", lastErr)
|
||||||
return fmt.Errorf("Time limit exceeded (%d seconds)", delta)
|
default:
|
||||||
}
|
}
|
||||||
|
|
||||||
stop, err := f()
|
stop, err := f()
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if stop {
|
if stop {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
if err != nil {
|
||||||
|
lastErr = err.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(time.Duration(interval) * time.Second)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -68,7 +68,7 @@ func (r *DNSProviderRoute53) changeRecord(action, fqdn, value string, ttl int) e
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return waitFor(90, func() (bool, error) {
|
return waitFor(90, 5, func() (bool, error) {
|
||||||
status, err := r.client.GetChange(resp.ChangeInfo.ID)
|
status, err := r.client.GetChange(resp.ChangeInfo.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
|
|
|
@ -6,13 +6,75 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"os"
|
"os"
|
||||||
|
"reflect"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var lookupNameserversTestsOK = []struct {
|
||||||
|
fqdn string
|
||||||
|
nss []string
|
||||||
|
}{
|
||||||
|
{"books.google.com.ng.",
|
||||||
|
[]string{"ns1.google.com.", "ns2.google.com.", "ns3.google.com.", "ns4.google.com."},
|
||||||
|
},
|
||||||
|
{"www.google.com.",
|
||||||
|
[]string{"ns1.google.com.", "ns2.google.com.", "ns3.google.com.", "ns4.google.com."},
|
||||||
|
},
|
||||||
|
{"physics.georgetown.edu.",
|
||||||
|
[]string{"ns1.georgetown.edu.", "ns2.georgetown.edu.", "ns3.georgetown.edu."},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var lookupNameserversTestsErr = []struct {
|
||||||
|
fqdn string
|
||||||
|
error string
|
||||||
|
}{
|
||||||
|
// invalid tld
|
||||||
|
{"_null.n0n0.",
|
||||||
|
"Could not determine authoritative nameservers",
|
||||||
|
},
|
||||||
|
// invalid domain
|
||||||
|
{"_null.com.",
|
||||||
|
"Could not determine authoritative nameservers",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var checkAuthoritativeNssTests = []struct {
|
||||||
|
fqdn, value string
|
||||||
|
ns []string
|
||||||
|
ok bool
|
||||||
|
}{
|
||||||
|
// TXT RR w/ expected value
|
||||||
|
{"8.8.8.8.asn.routeviews.org.", "151698.8.8.024", []string{"asnums.routeviews.org."},
|
||||||
|
true,
|
||||||
|
},
|
||||||
|
// No TXT RR
|
||||||
|
{"ns1.google.com.", "", []string{"ns2.google.com."},
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var checkAuthoritativeNssTestsErr = []struct {
|
||||||
|
fqdn, value string
|
||||||
|
ns []string
|
||||||
|
error string
|
||||||
|
}{
|
||||||
|
// TXT RR /w unexpected value
|
||||||
|
{"8.8.8.8.asn.routeviews.org.", "fe01=", []string{"asnums.routeviews.org."},
|
||||||
|
"did not return the expected TXT record",
|
||||||
|
},
|
||||||
|
// No TXT RR
|
||||||
|
{"ns1.google.com.", "fe01=", []string{"ns2.google.com."},
|
||||||
|
"did not return the expected TXT record",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
func TestDNSValidServerResponse(t *testing.T) {
|
func TestDNSValidServerResponse(t *testing.T) {
|
||||||
preCheckDNS = func(domain, fqdn string) bool {
|
preCheckDNS = func(fqdn, value string) (bool, error) {
|
||||||
return true
|
return true, nil
|
||||||
}
|
}
|
||||||
privKey, _ := generatePrivateKey(rsakey, 512)
|
privKey, _ := generatePrivateKey(rsakey, 512)
|
||||||
|
|
||||||
|
@ -39,7 +101,80 @@ func TestDNSValidServerResponse(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPreCheckDNS(t *testing.T) {
|
func TestPreCheckDNS(t *testing.T) {
|
||||||
if !preCheckDNS("api.letsencrypt.org", "acme-staging.api.letsencrypt.org") {
|
ok, err := preCheckDNS("acme-staging.api.letsencrypt.org", "fe01=")
|
||||||
|
if err != nil || !ok {
|
||||||
t.Errorf("preCheckDNS failed for acme-staging.api.letsencrypt.org")
|
t.Errorf("preCheckDNS failed for acme-staging.api.letsencrypt.org")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestLookupNameserversOK(t *testing.T) {
|
||||||
|
for _, tt := range lookupNameserversTestsOK {
|
||||||
|
nss, err := lookupNameservers(tt.fqdn)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("#%s: got %q; want nil", tt.fqdn, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Strings(nss)
|
||||||
|
sort.Strings(tt.nss)
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(nss, tt.nss) {
|
||||||
|
t.Errorf("#%s: got %v; want %v", tt.fqdn, nss, tt.nss)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLookupNameserversErr(t *testing.T) {
|
||||||
|
for _, tt := range lookupNameserversTestsErr {
|
||||||
|
_, err := lookupNameservers(tt.fqdn)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("#%s: expected %q (error); got <nil>", tt.fqdn, tt.error)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(err.Error(), tt.error) {
|
||||||
|
t.Errorf("#%s: expected %q (error); got %q", tt.fqdn, tt.error, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckAuthoritativeNss(t *testing.T) {
|
||||||
|
for _, tt := range checkAuthoritativeNssTests {
|
||||||
|
ok, _ := checkAuthoritativeNss(tt.fqdn, tt.value, tt.ns)
|
||||||
|
if ok != tt.ok {
|
||||||
|
t.Errorf("#%s: got %t; want %t", tt.fqdn, tt.ok)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckAuthoritativeNssErr(t *testing.T) {
|
||||||
|
for _, tt := range checkAuthoritativeNssTestsErr {
|
||||||
|
_, err := checkAuthoritativeNss(tt.fqdn, tt.value, tt.ns)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("#%s: expected %q (error); got <nil>", tt.fqdn, tt.error)
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), tt.error) {
|
||||||
|
t.Errorf("#%s: expected %q (error); got %q", tt.fqdn, tt.error, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWaitForTimeout(t *testing.T) {
|
||||||
|
c := make(chan error)
|
||||||
|
go func() {
|
||||||
|
err := waitFor(3, 1, func() (bool, error) {
|
||||||
|
return false, nil
|
||||||
|
})
|
||||||
|
c <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
timeout := time.After(4 * time.Second)
|
||||||
|
select {
|
||||||
|
case <-timeout:
|
||||||
|
t.Fatal("timeout exceeded")
|
||||||
|
case err := <-c:
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("expected timeout error; got <nil>", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue