Refactor DNS precheck
This commit is contained in:
parent
602aeba6c1
commit
5992793edd
2 changed files with 44 additions and 42 deletions
|
@ -13,11 +13,9 @@ import (
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
)
|
)
|
||||||
|
|
||||||
type preCheckDNSFunc func() bool
|
type preCheckDNSFunc func(domain, fqdn string) bool
|
||||||
|
|
||||||
var preCheckDNS = func() bool {
|
var preCheckDNS preCheckDNSFunc = checkDNS
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
var preCheckDNSFallbackCount = 5
|
var preCheckDNSFallbackCount = 5
|
||||||
|
|
||||||
|
@ -54,42 +52,7 @@ func (s *dnsChallenge) Solve(chlng challenge, domain string) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if preCheckDNS() {
|
preCheckDNS(domain, fqdn)
|
||||||
// check if the expected DNS entry was created. If not wait for some time and try again.
|
|
||||||
m := new(dns.Msg)
|
|
||||||
m.SetQuestion(domain+".", dns.TypeSOA)
|
|
||||||
c := new(dns.Client)
|
|
||||||
in, _, err := c.Exchange(m, "8.8.8.8:53")
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
var authorativeNS string
|
|
||||||
for _, answ := range in.Answer {
|
|
||||||
soa := answ.(*dns.SOA)
|
|
||||||
authorativeNS = soa.Ns
|
|
||||||
}
|
|
||||||
|
|
||||||
fallbackCnt := 0
|
|
||||||
for fallbackCnt < preCheckDNSFallbackCount {
|
|
||||||
m.SetQuestion(fqdn, dns.TypeTXT)
|
|
||||||
in, _, err = c.Exchange(m, authorativeNS+":53")
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(in.Answer) > 0 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
fallbackCnt++
|
|
||||||
if fallbackCnt >= preCheckDNSFallbackCount {
|
|
||||||
return errors.New("Could not retrieve the value from DNS in a timely manner. Aborting.")
|
|
||||||
}
|
|
||||||
|
|
||||||
time.Sleep(time.Second * time.Duration(fallbackCnt))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
jsonBytes, err := json.Marshal(challenge{Resource: "challenge", Type: chlng.Type, Token: chlng.Token, KeyAuthorization: keyAuth})
|
jsonBytes, err := json.Marshal(challenge{Resource: "challenge", Type: chlng.Type, Token: chlng.Token, KeyAuthorization: keyAuth})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -138,3 +101,42 @@ Loop:
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func checkDNS(domain, fqdn string) bool {
|
||||||
|
// check if the expected DNS entry was created. If not wait for some time and try again.
|
||||||
|
m := new(dns.Msg)
|
||||||
|
m.SetQuestion(domain+".", dns.TypeSOA)
|
||||||
|
c := new(dns.Client)
|
||||||
|
in, _, err := c.Exchange(m, "8.8.8.8:53")
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
var authorativeNS string
|
||||||
|
for _, answ := range in.Answer {
|
||||||
|
soa := answ.(*dns.SOA)
|
||||||
|
authorativeNS = soa.Ns
|
||||||
|
}
|
||||||
|
|
||||||
|
fallbackCnt := 0
|
||||||
|
for fallbackCnt < preCheckDNSFallbackCount {
|
||||||
|
m.SetQuestion(fqdn, dns.TypeTXT)
|
||||||
|
in, _, err = c.Exchange(m, authorativeNS+":53")
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(in.Answer) > 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
fallbackCnt++
|
||||||
|
if fallbackCnt >= preCheckDNSFallbackCount {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(time.Second * time.Duration(fallbackCnt))
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
|
@ -11,8 +11,8 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestDNSValidServerResponse(t *testing.T) {
|
func TestDNSValidServerResponse(t *testing.T) {
|
||||||
preCheckDNS = func() bool {
|
preCheckDNS = func(domain, fqdn string) bool {
|
||||||
return false
|
return true
|
||||||
}
|
}
|
||||||
privKey, _ := generatePrivateKey(rsakey, 512)
|
privKey, _ := generatePrivateKey(rsakey, 512)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue