forked from TrueCloudLab/lego
commit
ba64faa4e1
3 changed files with 256 additions and 42 deletions
|
@ -6,17 +6,18 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
type preCheckDNSFunc func(domain, fqdn string) bool
|
||||
type preCheckDNSFunc func(fqdn, value string) (bool, error)
|
||||
|
||||
var preCheckDNS preCheckDNSFunc = checkDNS
|
||||
var preCheckDNS preCheckDNSFunc = checkDnsPropagation
|
||||
|
||||
var preCheckDNSFallbackCount = 5
|
||||
var recursiveNameserver = "google-public-dns-a.google.com"
|
||||
|
||||
// DNS01Record returns a DNS record which will fulfill the `dns-01` challenge
|
||||
func DNS01Record(domain, keyAuth string) (fqdn string, value string, ttl int) {
|
||||
|
@ -60,50 +61,125 @@ func (s *dnsChallenge) Solve(chlng challenge, domain string) error {
|
|||
}
|
||||
}()
|
||||
|
||||
fqdn, _, _ := DNS01Record(domain, keyAuth)
|
||||
fqdn, value, _ := DNS01Record(domain, keyAuth)
|
||||
|
||||
preCheckDNS(domain, fqdn)
|
||||
logf("[INFO][%s] Checking DNS record propagation...", domain)
|
||||
|
||||
err = waitFor(30, 2, func() (bool, error) {
|
||||
return preCheckDNS(fqdn, value)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.validate(s.jws, domain, chlng.URI, challenge{Resource: "challenge", Type: chlng.Type, Token: chlng.Token, KeyAuthorization: keyAuth})
|
||||
}
|
||||
|
||||
func checkDNS(domain, fqdn string) bool {
|
||||
// check if the expected DNS entry was created. If not wait for some time and try again.
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion(domain+".", dns.TypeSOA)
|
||||
c := new(dns.Client)
|
||||
in, _, err := c.Exchange(m, "google-public-dns-a.google.com:53")
|
||||
// checkDnsPropagation checks if the expected TXT record has been propagated to all authoritative nameservers.
|
||||
func checkDnsPropagation(fqdn, value string) (bool, error) {
|
||||
// Initial attempt to resolve at the recursive NS
|
||||
r, err := dnsQuery(fqdn, dns.TypeTXT, recursiveNameserver, true)
|
||||
if err != nil {
|
||||
return false
|
||||
return false, err
|
||||
}
|
||||
if r.Rcode != dns.RcodeSuccess {
|
||||
return false, fmt.Errorf("Could not resolve %s -> %s", fqdn, dns.RcodeToString[r.Rcode])
|
||||
}
|
||||
|
||||
var authorativeNS string
|
||||
for _, answ := range in.Answer {
|
||||
soa := answ.(*dns.SOA)
|
||||
authorativeNS = soa.Ns
|
||||
// If we see a CNAME here then use the alias
|
||||
for _, rr := range r.Answer {
|
||||
if cn, ok := rr.(*dns.CNAME); ok {
|
||||
if cn.Hdr.Name == fqdn {
|
||||
fqdn = cn.Target
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fallbackCnt := 0
|
||||
for fallbackCnt < preCheckDNSFallbackCount {
|
||||
m.SetQuestion(fqdn, dns.TypeTXT)
|
||||
in, _, err = c.Exchange(m, authorativeNS+":53")
|
||||
authoritativeNss, err := lookupNameservers(fqdn)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return checkAuthoritativeNss(fqdn, value, authoritativeNss)
|
||||
}
|
||||
|
||||
// checkAuthoritativeNss queries each of the given nameservers for the expected TXT record.
|
||||
func checkAuthoritativeNss(fqdn, value string, nameservers []string) (bool, error) {
|
||||
for _, ns := range nameservers {
|
||||
r, err := dnsQuery(fqdn, dns.TypeTXT, ns, false)
|
||||
if err != nil {
|
||||
return false
|
||||
return false, err
|
||||
}
|
||||
|
||||
if len(in.Answer) > 0 {
|
||||
return true
|
||||
if r.Rcode != dns.RcodeSuccess {
|
||||
return false, fmt.Errorf("NS %s returned %s for %s", ns, dns.RcodeToString[r.Rcode], fqdn)
|
||||
}
|
||||
|
||||
fallbackCnt++
|
||||
if fallbackCnt >= preCheckDNSFallbackCount {
|
||||
return false
|
||||
var found bool
|
||||
for _, rr := range r.Answer {
|
||||
if txt, ok := rr.(*dns.TXT); ok {
|
||||
if strings.Join(txt.Txt, "") == value {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(time.Second * time.Duration(fallbackCnt))
|
||||
if !found {
|
||||
return false, fmt.Errorf("NS %s did not return the expected TXT record", ns)
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// dnsQuery sends a DNS query to the given nameserver.
|
||||
func dnsQuery(fqdn string, rtype uint16, nameserver string, recursive bool) (in *dns.Msg, err error) {
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion(fqdn, rtype)
|
||||
m.SetEdns0(4096, false)
|
||||
|
||||
if !recursive {
|
||||
m.RecursionDesired = false
|
||||
}
|
||||
|
||||
in, err = dns.Exchange(m, net.JoinHostPort(nameserver, "53"))
|
||||
if err == dns.ErrTruncated {
|
||||
tcp := &dns.Client{Net: "tcp"}
|
||||
in, _, err = tcp.Exchange(m, nameserver)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// lookupNameservers returns the authoritative nameservers for the given fqdn.
|
||||
func lookupNameservers(fqdn string) ([]string, error) {
|
||||
var authoritativeNss []string
|
||||
|
||||
r, err := dnsQuery(fqdn, dns.TypeNS, recursiveNameserver, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, rr := range r.Answer {
|
||||
if ns, ok := rr.(*dns.NS); ok {
|
||||
authoritativeNss = append(authoritativeNss, strings.ToLower(ns.Ns))
|
||||
}
|
||||
}
|
||||
|
||||
if len(authoritativeNss) > 0 {
|
||||
return authoritativeNss, nil
|
||||
}
|
||||
|
||||
// Strip of the left most label to get the parent domain.
|
||||
offset, _ := dns.NextLabel(fqdn, 0)
|
||||
next := fqdn[offset:]
|
||||
if dns.CountLabel(next) < 2 {
|
||||
return nil, fmt.Errorf("Could not determine authoritative nameservers")
|
||||
}
|
||||
|
||||
return lookupNameservers(next)
|
||||
}
|
||||
|
||||
// toFqdn converts the name into a fqdn appending a trailing dot.
|
||||
|
@ -124,22 +200,25 @@ func unFqdn(name string) string {
|
|||
return name
|
||||
}
|
||||
|
||||
// waitFor polls the given function 'f', once per second, up to 'timeout' seconds.
|
||||
func waitFor(timeout int, f func() (bool, error)) error {
|
||||
start := time.Now().Second()
|
||||
// waitFor polls the given function 'f', once every 'interval' seconds, up to 'timeout' seconds.
|
||||
func waitFor(timeout, interval int, f func() (bool, error)) error {
|
||||
var lastErr string
|
||||
timeup := time.After(time.Duration(timeout) * time.Second)
|
||||
for {
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
if delta := time.Now().Second() - start; delta >= timeout {
|
||||
return fmt.Errorf("Time limit exceeded (%d seconds)", delta)
|
||||
select {
|
||||
case <-timeup:
|
||||
return fmt.Errorf("Time limit exceeded. Last error: %s", lastErr)
|
||||
default:
|
||||
}
|
||||
|
||||
stop, err := f()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if stop {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
lastErr = err.Error()
|
||||
}
|
||||
|
||||
time.Sleep(time.Duration(interval) * time.Second)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -68,7 +68,7 @@ func (r *DNSProviderRoute53) changeRecord(action, fqdn, value string, ttl int) e
|
|||
return err
|
||||
}
|
||||
|
||||
return waitFor(90, func() (bool, error) {
|
||||
return waitFor(90, 5, func() (bool, error) {
|
||||
status, err := r.client.GetChange(resp.ChangeInfo.ID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
|
|
|
@ -6,13 +6,75 @@ import (
|
|||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
var lookupNameserversTestsOK = []struct {
|
||||
fqdn string
|
||||
nss []string
|
||||
}{
|
||||
{"books.google.com.ng.",
|
||||
[]string{"ns1.google.com.", "ns2.google.com.", "ns3.google.com.", "ns4.google.com."},
|
||||
},
|
||||
{"www.google.com.",
|
||||
[]string{"ns1.google.com.", "ns2.google.com.", "ns3.google.com.", "ns4.google.com."},
|
||||
},
|
||||
{"physics.georgetown.edu.",
|
||||
[]string{"ns1.georgetown.edu.", "ns2.georgetown.edu.", "ns3.georgetown.edu."},
|
||||
},
|
||||
}
|
||||
|
||||
var lookupNameserversTestsErr = []struct {
|
||||
fqdn string
|
||||
error string
|
||||
}{
|
||||
// invalid tld
|
||||
{"_null.n0n0.",
|
||||
"Could not determine authoritative nameservers",
|
||||
},
|
||||
// invalid domain
|
||||
{"_null.com.",
|
||||
"Could not determine authoritative nameservers",
|
||||
},
|
||||
}
|
||||
|
||||
var checkAuthoritativeNssTests = []struct {
|
||||
fqdn, value string
|
||||
ns []string
|
||||
ok bool
|
||||
}{
|
||||
// TXT RR w/ expected value
|
||||
{"8.8.8.8.asn.routeviews.org.", "151698.8.8.024", []string{"asnums.routeviews.org."},
|
||||
true,
|
||||
},
|
||||
// No TXT RR
|
||||
{"ns1.google.com.", "", []string{"ns2.google.com."},
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
var checkAuthoritativeNssTestsErr = []struct {
|
||||
fqdn, value string
|
||||
ns []string
|
||||
error string
|
||||
}{
|
||||
// TXT RR /w unexpected value
|
||||
{"8.8.8.8.asn.routeviews.org.", "fe01=", []string{"asnums.routeviews.org."},
|
||||
"did not return the expected TXT record",
|
||||
},
|
||||
// No TXT RR
|
||||
{"ns1.google.com.", "fe01=", []string{"ns2.google.com."},
|
||||
"did not return the expected TXT record",
|
||||
},
|
||||
}
|
||||
|
||||
func TestDNSValidServerResponse(t *testing.T) {
|
||||
preCheckDNS = func(domain, fqdn string) bool {
|
||||
return true
|
||||
preCheckDNS = func(fqdn, value string) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
privKey, _ := generatePrivateKey(rsakey, 512)
|
||||
|
||||
|
@ -39,7 +101,80 @@ func TestDNSValidServerResponse(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestPreCheckDNS(t *testing.T) {
|
||||
if !preCheckDNS("api.letsencrypt.org", "acme-staging.api.letsencrypt.org") {
|
||||
ok, err := preCheckDNS("acme-staging.api.letsencrypt.org", "fe01=")
|
||||
if err != nil || !ok {
|
||||
t.Errorf("preCheckDNS failed for acme-staging.api.letsencrypt.org")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLookupNameserversOK(t *testing.T) {
|
||||
for _, tt := range lookupNameserversTestsOK {
|
||||
nss, err := lookupNameservers(tt.fqdn)
|
||||
if err != nil {
|
||||
t.Fatalf("#%s: got %q; want nil", tt.fqdn, err)
|
||||
}
|
||||
|
||||
sort.Strings(nss)
|
||||
sort.Strings(tt.nss)
|
||||
|
||||
if !reflect.DeepEqual(nss, tt.nss) {
|
||||
t.Errorf("#%s: got %v; want %v", tt.fqdn, nss, tt.nss)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLookupNameserversErr(t *testing.T) {
|
||||
for _, tt := range lookupNameserversTestsErr {
|
||||
_, err := lookupNameservers(tt.fqdn)
|
||||
if err == nil {
|
||||
t.Fatalf("#%s: expected %q (error); got <nil>", tt.fqdn, tt.error)
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), tt.error) {
|
||||
t.Errorf("#%s: expected %q (error); got %q", tt.fqdn, tt.error, err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckAuthoritativeNss(t *testing.T) {
|
||||
for _, tt := range checkAuthoritativeNssTests {
|
||||
ok, _ := checkAuthoritativeNss(tt.fqdn, tt.value, tt.ns)
|
||||
if ok != tt.ok {
|
||||
t.Errorf("#%s: got %t; want %t", tt.fqdn, tt.ok)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckAuthoritativeNssErr(t *testing.T) {
|
||||
for _, tt := range checkAuthoritativeNssTestsErr {
|
||||
_, err := checkAuthoritativeNss(tt.fqdn, tt.value, tt.ns)
|
||||
if err == nil {
|
||||
t.Fatalf("#%s: expected %q (error); got <nil>", tt.fqdn, tt.error)
|
||||
}
|
||||
if !strings.Contains(err.Error(), tt.error) {
|
||||
t.Errorf("#%s: expected %q (error); got %q", tt.fqdn, tt.error, err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWaitForTimeout(t *testing.T) {
|
||||
c := make(chan error)
|
||||
go func() {
|
||||
err := waitFor(3, 1, func() (bool, error) {
|
||||
return false, nil
|
||||
})
|
||||
c <- err
|
||||
}()
|
||||
|
||||
timeout := time.After(4 * time.Second)
|
||||
select {
|
||||
case <-timeout:
|
||||
t.Fatal("timeout exceeded")
|
||||
case err := <-c:
|
||||
if err == nil {
|
||||
t.Errorf("expected timeout error; got <nil>", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue