From 5992793edd44b2adbf095c1bc70245bb2a440b2c Mon Sep 17 00:00:00 2001 From: xenolf Date: Fri, 22 Jan 2016 02:25:27 +0100 Subject: [PATCH] Refactor DNS precheck --- acme/dns_challenge.go | 82 +++++++++++++++++++------------------- acme/dns_challenge_test.go | 4 +- 2 files changed, 44 insertions(+), 42 deletions(-) diff --git a/acme/dns_challenge.go b/acme/dns_challenge.go index 198a1bcb..aae977ca 100644 --- a/acme/dns_challenge.go +++ b/acme/dns_challenge.go @@ -13,11 +13,9 @@ import ( "github.com/miekg/dns" ) -type preCheckDNSFunc func() bool +type preCheckDNSFunc func(domain, fqdn string) bool -var preCheckDNS = func() bool { - return true -} +var preCheckDNS preCheckDNSFunc = checkDNS var preCheckDNSFallbackCount = 5 @@ -54,42 +52,7 @@ func (s *dnsChallenge) Solve(chlng challenge, domain string) error { return err } - if preCheckDNS() { - // check if the expected DNS entry was created. If not wait for some time and try again. - m := new(dns.Msg) - m.SetQuestion(domain+".", dns.TypeSOA) - c := new(dns.Client) - in, _, err := c.Exchange(m, "8.8.8.8:53") - if err != nil { - return err - } - - var authorativeNS string - for _, answ := range in.Answer { - soa := answ.(*dns.SOA) - authorativeNS = soa.Ns - } - - fallbackCnt := 0 - for fallbackCnt < preCheckDNSFallbackCount { - m.SetQuestion(fqdn, dns.TypeTXT) - in, _, err = c.Exchange(m, authorativeNS+":53") - if err != nil { - return err - } - - if len(in.Answer) > 0 { - break - } - - fallbackCnt++ - if fallbackCnt >= preCheckDNSFallbackCount { - return errors.New("Could not retrieve the value from DNS in a timely manner. Aborting.") - } - - time.Sleep(time.Second * time.Duration(fallbackCnt)) - } - } + preCheckDNS(domain, fqdn) jsonBytes, err := json.Marshal(challenge{Resource: "challenge", Type: chlng.Type, Token: chlng.Token, KeyAuthorization: keyAuth}) if err != nil { @@ -138,3 +101,42 @@ Loop: return nil } + +func checkDNS(domain, fqdn string) bool { + // check if the expected DNS entry was created. If not wait for some time and try again. + m := new(dns.Msg) + m.SetQuestion(domain+".", dns.TypeSOA) + c := new(dns.Client) + in, _, err := c.Exchange(m, "8.8.8.8:53") + if err != nil { + return false + } + + var authorativeNS string + for _, answ := range in.Answer { + soa := answ.(*dns.SOA) + authorativeNS = soa.Ns + } + + fallbackCnt := 0 + for fallbackCnt < preCheckDNSFallbackCount { + m.SetQuestion(fqdn, dns.TypeTXT) + in, _, err = c.Exchange(m, authorativeNS+":53") + if err != nil { + return false + } + + if len(in.Answer) > 0 { + return true + } + + fallbackCnt++ + if fallbackCnt >= preCheckDNSFallbackCount { + return false + } + + time.Sleep(time.Second * time.Duration(fallbackCnt)) + } + + return false +} diff --git a/acme/dns_challenge_test.go b/acme/dns_challenge_test.go index e69d7c4b..6a76cf8f 100644 --- a/acme/dns_challenge_test.go +++ b/acme/dns_challenge_test.go @@ -11,8 +11,8 @@ import ( ) func TestDNSValidServerResponse(t *testing.T) { - preCheckDNS = func() bool { - return false + preCheckDNS = func(domain, fqdn string) bool { + return true } privKey, _ := generatePrivateKey(rsakey, 512)