diff --git a/.golangci.toml b/.golangci.toml index 5ff9033e..ced70b32 100644 --- a/.golangci.toml +++ b/.golangci.toml @@ -57,7 +57,10 @@ text = "`(tlsFeatureExtensionOID|ocspMustStapleFeature)` is a global variable" [[issues.exclude-rules]] path = "challenge/dns01/nameserver.go" - text = "`(defaultNameservers|recursiveNameservers|dnsTimeout|fqdnToZone|muFqdnToZone)` is a global variable" + text = "`(defaultNameservers|recursiveNameservers|dnsTimeout|fqdnSoaCache|muFqdnSoaCache)` is a global variable" + [[issues.exclude-rules]] + path = "challenge/dns01/nameserver_test.go" + text = "`findXByFqdnTestCases` is a global variable" [[issues.exclude-rules]] path = "challenge/http01/domain_matcher.go" text = "string `Host` has \\d occurrences, make it a constant" diff --git a/challenge/dns01/nameserver.go b/challenge/dns01/nameserver.go index 03f1a8d1..6db6c589 100644 --- a/challenge/dns01/nameserver.go +++ b/challenge/dns01/nameserver.go @@ -16,8 +16,8 @@ const defaultResolvConf = "/etc/resolv.conf" var dnsTimeout = 10 * time.Second var ( - fqdnToZone = map[string]string{} - muFqdnToZone sync.Mutex + fqdnSoaCache = map[string]*soaCacheEntry{} + muFqdnSoaCache sync.Mutex ) var defaultNameservers = []string{ @@ -28,11 +28,31 @@ var defaultNameservers = []string{ // recursiveNameservers are used to pre-check DNS propagation var recursiveNameservers = getNameservers(defaultResolvConf, defaultNameservers) +// soaCacheEntry holds a cached SOA record (only selected fields) +type soaCacheEntry struct { + zone string // zone apex (a domain name) + primaryNs string // primary nameserver for the zone apex + expires time.Time // time when this cache entry should be evicted +} + +func newSoaCacheEntry(soa *dns.SOA) *soaCacheEntry { + return &soaCacheEntry{ + zone: soa.Hdr.Name, + primaryNs: soa.Ns, + expires: time.Now().Add(time.Duration(soa.Refresh) * time.Second), + } +} + +// isExpired checks whether a cache entry should be considered expired. +func (cache *soaCacheEntry) isExpired() bool { + return time.Now().After(cache.expires) +} + // ClearFqdnCache clears the cache of fqdn to zone mappings. Primarily used in testing. func ClearFqdnCache() { - muFqdnToZone.Lock() - fqdnToZone = map[string]string{} - muFqdnToZone.Unlock() + muFqdnSoaCache.Lock() + fqdnSoaCache = map[string]*soaCacheEntry{} + muFqdnSoaCache.Unlock() } func AddDNSTimeout(timeout time.Duration) ChallengeOption { @@ -98,6 +118,22 @@ func lookupNameservers(fqdn string) ([]string, error) { return nil, fmt.Errorf("could not determine authoritative nameservers") } +// FindPrimaryNsByFqdn determines the primary nameserver of the zone apex for the given fqdn +// by recursing up the domain labels until the nameserver returns a SOA record in the answer section. +func FindPrimaryNsByFqdn(fqdn string) (string, error) { + return FindPrimaryNsByFqdnCustom(fqdn, recursiveNameservers) +} + +// FindPrimaryNsByFqdnCustom determines the primary nameserver of the zone apex for the given fqdn +// by recursing up the domain labels until the nameserver returns a SOA record in the answer section. +func FindPrimaryNsByFqdnCustom(fqdn string, nameservers []string) (string, error) { + soa, err := lookupSoaByFqdn(fqdn, nameservers) + if err != nil { + return "", err + } + return soa.primaryNs, nil +} + // FindZoneByFqdn determines the zone apex for the given fqdn // by recursing up the domain labels until the nameserver returns a SOA record in the answer section. func FindZoneByFqdn(fqdn string) (string, error) { @@ -107,14 +143,32 @@ func FindZoneByFqdn(fqdn string) (string, error) { // FindZoneByFqdnCustom determines the zone apex for the given fqdn // by recursing up the domain labels until the nameserver returns a SOA record in the answer section. func FindZoneByFqdnCustom(fqdn string, nameservers []string) (string, error) { - muFqdnToZone.Lock() - defer muFqdnToZone.Unlock() + soa, err := lookupSoaByFqdn(fqdn, nameservers) + if err != nil { + return "", err + } + return soa.zone, nil +} - // Do we have it cached? - if zone, ok := fqdnToZone[fqdn]; ok { - return zone, nil +func lookupSoaByFqdn(fqdn string, nameservers []string) (*soaCacheEntry, error) { + muFqdnSoaCache.Lock() + defer muFqdnSoaCache.Unlock() + + // Do we have it cached and is it still fresh? + if ent := fqdnSoaCache[fqdn]; ent != nil && !ent.isExpired() { + return ent, nil } + ent, err := fetchSoaByFqdn(fqdn, nameservers) + if err != nil { + return nil, err + } + + fqdnSoaCache[fqdn] = ent + return ent, nil +} + +func fetchSoaByFqdn(fqdn string, nameservers []string) (*soaCacheEntry, error) { var err error var in *dns.Msg @@ -134,7 +188,6 @@ func FindZoneByFqdnCustom(fqdn string, nameservers []string) (string, error) { switch in.Rcode { case dns.RcodeSuccess: // Check if we got a SOA RR in the answer section - if len(in.Answer) == 0 { continue } @@ -147,20 +200,18 @@ func FindZoneByFqdnCustom(fqdn string, nameservers []string) (string, error) { for _, ans := range in.Answer { if soa, ok := ans.(*dns.SOA); ok { - zone := soa.Hdr.Name - fqdnToZone[fqdn] = zone - return zone, nil + return newSoaCacheEntry(soa), nil } } case dns.RcodeNameError: // NXDOMAIN default: // Any response code other than NOERROR and NXDOMAIN is treated as error - return "", fmt.Errorf("unexpected response code '%s' for %s", dns.RcodeToString[in.Rcode], domain) + return nil, fmt.Errorf("unexpected response code '%s' for %s", dns.RcodeToString[in.Rcode], domain) } } - return "", fmt.Errorf("could not find the start of authority for %s%s", fqdn, formatDNSError(in, err)) + return nil, fmt.Errorf("could not find the start of authority for %s%s", fqdn, formatDNSError(in, err)) } // dnsMsgContainsCNAME checks for a CNAME answer in msg diff --git a/challenge/dns01/nameserver_test.go b/challenge/dns01/nameserver_test.go index 37b73c78..740cb220 100644 --- a/challenge/dns01/nameserver_test.go +++ b/challenge/dns01/nameserver_test.go @@ -68,68 +68,74 @@ func TestLookupNameserversErr(t *testing.T) { } } -func TestFindZoneByFqdnCustom(t *testing.T) { - testCases := []struct { - desc string - fqdn string - zone string - nameservers []string - expectedError string - }{ - { - desc: "domain is a CNAME", - fqdn: "mail.google.com.", - zone: "google.com.", - nameservers: recursiveNameservers, - }, - { - desc: "domain is a non-existent subdomain", - fqdn: "foo.google.com.", - zone: "google.com.", - nameservers: recursiveNameservers, - }, - { - desc: "domain is a eTLD", - fqdn: "example.com.ac.", - zone: "ac.", - nameservers: recursiveNameservers, - }, - { - desc: "domain is a cross-zone CNAME", - fqdn: "cross-zone-example.assets.sh.", - zone: "assets.sh.", - nameservers: recursiveNameservers, - }, - { - desc: "NXDOMAIN", - fqdn: "test.loho.jkl.", - zone: "loho.jkl.", - nameservers: []string{"1.1.1.1:53"}, - expectedError: "could not find the start of authority for test.loho.jkl.: NXDOMAIN", - }, - { - desc: "several non existent nameservers", - fqdn: "mail.google.com.", - zone: "google.com.", - nameservers: []string{":7053", ":8053", "1.1.1.1:53"}, - }, - { - desc: "only non existent nameservers", - fqdn: "mail.google.com.", - zone: "google.com.", - nameservers: []string{":7053", ":8053", ":9053"}, - expectedError: "could not find the start of authority for mail.google.com.: read udp", - }, - { - desc: "no nameservers", - fqdn: "test.ldez.com.", - zone: "ldez.com.", - nameservers: []string{}, - expectedError: "could not find the start of authority for test.ldez.com.", - }, - } +var findXByFqdnTestCases = []struct { + desc string + fqdn string + zone string + primaryNs string + nameservers []string + expectedError string +}{ + { + desc: "domain is a CNAME", + fqdn: "mail.google.com.", + zone: "google.com.", + primaryNs: "ns1.google.com.", + nameservers: recursiveNameservers, + }, + { + desc: "domain is a non-existent subdomain", + fqdn: "foo.google.com.", + zone: "google.com.", + primaryNs: "ns1.google.com.", + nameservers: recursiveNameservers, + }, + { + desc: "domain is a eTLD", + fqdn: "example.com.ac.", + zone: "ac.", + primaryNs: "a0.nic.ac.", + nameservers: recursiveNameservers, + }, + { + desc: "domain is a cross-zone CNAME", + fqdn: "cross-zone-example.assets.sh.", + zone: "assets.sh.", + primaryNs: "gina.ns.cloudflare.com.", + nameservers: recursiveNameservers, + }, + { + desc: "NXDOMAIN", + fqdn: "test.loho.jkl.", + zone: "loho.jkl.", + nameservers: []string{"1.1.1.1:53"}, + expectedError: "could not find the start of authority for test.loho.jkl.: NXDOMAIN", + }, + { + desc: "several non existent nameservers", + fqdn: "mail.google.com.", + zone: "google.com.", + primaryNs: "ns1.google.com.", + nameservers: []string{":7053", ":8053", "1.1.1.1:53"}, + }, + { + desc: "only non existent nameservers", + fqdn: "mail.google.com.", + zone: "google.com.", + nameservers: []string{":7053", ":8053", ":9053"}, + expectedError: "could not find the start of authority for mail.google.com.: read udp", + }, + { + desc: "no nameservers", + fqdn: "test.ldez.com.", + zone: "ldez.com.", + nameservers: []string{}, + expectedError: "could not find the start of authority for test.ldez.com.", + }, +} - for _, test := range testCases { +func TestFindZoneByFqdnCustom(t *testing.T) { + for _, test := range findXByFqdnTestCases { t.Run(test.desc, func(t *testing.T) { ClearFqdnCache() @@ -145,6 +151,23 @@ func TestFindZoneByFqdnCustom(t *testing.T) { } } +func TestFindPrimayNsByFqdnCustom(t *testing.T) { + for _, test := range findXByFqdnTestCases { + t.Run(test.desc, func(t *testing.T) { + ClearFqdnCache() + + ns, err := FindPrimaryNsByFqdnCustom(test.fqdn, test.nameservers) + if test.expectedError != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), test.expectedError) + } else { + require.NoError(t, err) + assert.Equal(t, test.primaryNs, ns) + } + }) + } +} + func TestResolveConfServers(t *testing.T) { var testCases = []struct { fixture string