diff --git a/middleware/file/cname_test.go b/middleware/file/cname_test.go new file mode 100644 index 000000000..800020068 --- /dev/null +++ b/middleware/file/cname_test.go @@ -0,0 +1,98 @@ +package file + +import ( + "sort" + "strings" + "testing" + + "github.com/miekg/coredns/middleware/pkg/dnsrecorder" + "github.com/miekg/coredns/middleware/test" + + "github.com/miekg/dns" + "golang.org/x/net/context" +) + +func TestLookupCNAMEChain(t *testing.T) { + name := "example.org." + zone, err := Parse(strings.NewReader(dbExampleCNAME), name, "stdin") + if err != nil { + t.Fatalf("Expected no error when reading zone, got %q", err) + } + + fm := File{Next: test.ErrorHandler(), Zones: Zones{Z: map[string]*Zone{name: zone}, Names: []string{name}}} + ctx := context.TODO() + + for _, tc := range cnameTestCases { + m := tc.Msg() + + rec := dnsrecorder.New(&test.ResponseWriter{}) + _, err := fm.ServeDNS(ctx, rec, m) + if err != nil { + t.Errorf("Expected no error, got %v\n", err) + return + } + + resp := rec.Msg + sort.Sort(test.RRSet(resp.Answer)) + sort.Sort(test.RRSet(resp.Ns)) + sort.Sort(test.RRSet(resp.Extra)) + + if !test.Header(t, tc, resp) { + t.Logf("%v\n", resp) + continue + } + + if !test.Section(t, tc, test.Answer, resp.Answer) { + t.Logf("%v\n", resp) + } + if !test.Section(t, tc, test.Ns, resp.Ns) { + t.Logf("%v\n", resp) + + } + if !test.Section(t, tc, test.Extra, resp.Extra) { + t.Logf("%v\n", resp) + } + } +} + +var cnameTestCases = []test.Case{ + { + Qname: "a.example.org.", Qtype: dns.TypeA, + Answer: []dns.RR{ + test.A("a.example.org. 1800 IN A 127.0.0.1"), + }, + }, + { + Qname: "www3.example.org.", Qtype: dns.TypeCNAME, + Answer: []dns.RR{ + test.CNAME("www3.example.org. 1800 IN CNAME www2.example.org."), + }, + }, + { + Qname: "www3.example.org.", Qtype: dns.TypeA, + Answer: []dns.RR{ + test.A("a.example.org. 1800 IN A 127.0.0.1"), + test.CNAME("www.example.org. 1800 IN CNAME a.example.org."), + test.CNAME("www1.example.org. 1800 IN CNAME www.example.org."), + test.CNAME("www2.example.org. 1800 IN CNAME www1.example.org."), + test.CNAME("www3.example.org. 1800 IN CNAME www2.example.org."), + }, + }, +} + +const dbExampleCNAME = ` +$TTL 30M +$ORIGIN example.org. +@ IN SOA linode.atoom.net. miek.miek.nl. ( + 1282630057 ; Serial + 4H ; Refresh + 1H ; Retry + 7D ; Expire + 4H ) ; Negative Cache TTL + +a IN A 127.0.0.1 +www3 IN CNAME www2 +www2 IN CNAME www1 +www1 IN CNAME www +www IN CNAME a +dangling IN CNAME foo` diff --git a/middleware/file/dnssec_test.go b/middleware/file/dnssec_test.go index 7f4a0916b..40e6429ea 100644 --- a/middleware/file/dnssec_test.go +++ b/middleware/file/dnssec_test.go @@ -58,6 +58,7 @@ var dnssecTestCases = []test.Case{ test.A("a.miek.nl. 1800 IN A 139.162.196.78"), test.RRSIG("a.miek.nl. 1800 IN RRSIG A 8 3 1800 20160426031301 20160327031301 12051 miek.nl. lxLotCjWZ3kihTxk="), test.CNAME("www.miek.nl. 1800 IN CNAME a.miek.nl."), + test.RRSIG("www.miek.nl. 1800 RRSIG CNAME 8 3 1800 20160426031301 20160327031301 12051 miek.nl. NVZmMJaypS+wDL2Lar4Zw1zF"), }, Extra: []dns.RR{ @@ -118,7 +119,7 @@ var dnssecTestCases = []test.Case{ func TestLookupDNSSEC(t *testing.T) { zone, err := Parse(strings.NewReader(dbMiekNLSigned), testzone, "stdin") if err != nil { - t.Fatalf("expect no error when reading zone, got %q", err) + t.Fatalf("Expected no error when reading zone, got %q", err) } fm := File{Next: test.ErrorHandler(), Zones: Zones{Z: map[string]*Zone{testzone: zone}, Names: []string{testzone}}} @@ -130,7 +131,7 @@ func TestLookupDNSSEC(t *testing.T) { rec := dnsrecorder.New(&test.ResponseWriter{}) _, err := fm.ServeDNS(ctx, rec, m) if err != nil { - t.Errorf("expected no error, got %v\n", err) + t.Errorf("Expected no error, got %v\n", err) return } diff --git a/middleware/file/ds_test.go b/middleware/file/ds_test.go index a7ba06263..b89c0e15c 100644 --- a/middleware/file/ds_test.go +++ b/middleware/file/ds_test.go @@ -54,7 +54,7 @@ var dsTestCases = []test.Case{ func TestLookupDS(t *testing.T) { zone, err := Parse(strings.NewReader(dbMiekNLDelegation), testzone, "stdin") if err != nil { - t.Fatalf("expect no error when reading zone, got %q", err) + t.Fatalf("Expected no error when reading zone, got %q", err) } fm := File{Next: test.ErrorHandler(), Zones: Zones{Z: map[string]*Zone{testzone: zone}, Names: []string{testzone}}} @@ -66,7 +66,7 @@ func TestLookupDS(t *testing.T) { rec := dnsrecorder.New(&test.ResponseWriter{}) _, err := fm.ServeDNS(ctx, rec, m) if err != nil { - t.Errorf("expected no error, got %v\n", err) + t.Errorf("Expected no error, got %v\n", err) return } diff --git a/middleware/file/lookup.go b/middleware/file/lookup.go index 95cd02e73..c47f1f5fa 100644 --- a/middleware/file/lookup.go +++ b/middleware/file/lookup.go @@ -118,8 +118,9 @@ func (z *Zone) Lookup(qname string, qtype uint16, do bool) ([]dns.RR, []dns.RR, // Found entire name. if found && shot { - if rrs := elem.Types(dns.TypeCNAME, qname); len(rrs) > 0 { - return z.searchCNAME(rrs, qtype, do) + // DNAME... + if rrs := elem.Types(dns.TypeCNAME); len(rrs) > 0 && qtype != dns.TypeCNAME { + return z.searchCNAME(elem, rrs, qtype, do) } rrs := elem.Types(qtype, qname) @@ -151,7 +152,7 @@ func (z *Zone) Lookup(qname string, qtype uint16, do bool) ([]dns.RR, []dns.RR, auth := []dns.RR{} if rrs := wildElem.Types(dns.TypeCNAME, qname); len(rrs) > 0 { - return z.searchCNAME(rrs, qtype, do) + return z.searchCNAME(wildElem, rrs, qtype, do) } rrs := wildElem.Types(qtype, qname) @@ -250,22 +251,61 @@ func (z *Zone) ns(do bool) []dns.RR { return z.Apex.NS } -func (z *Zone) searchCNAME(rrs []dns.RR, qtype uint16, do bool) ([]dns.RR, []dns.RR, []dns.RR, Result) { - elem, _ := z.Tree.Search(rrs[0].(*dns.CNAME).Target) +func (z *Zone) searchCNAME(elem *tree.Elem, rrs []dns.RR, qtype uint16, do bool) ([]dns.RR, []dns.RR, []dns.RR, Result) { + if do { + sigs := elem.Types(dns.TypeRRSIG) + sigs = signatureForSubType(sigs, dns.TypeCNAME) + if len(sigs) > 0 { + rrs = append(rrs, sigs...) + } + } + + elem, _ = z.Tree.Search(rrs[0].(*dns.CNAME).Target) if elem == nil { return rrs, nil, nil, Success } - // RECURSIVE SEARCH, up to 8 deep. Also: tests. + i := 0 + +Redo: + cname := elem.Types(dns.TypeCNAME) + if len(cname) > 0 { + rrs = append(rrs, cname...) + + if do { + sigs := elem.Types(dns.TypeRRSIG) + sigs = signatureForSubType(sigs, dns.TypeCNAME) + if len(sigs) > 0 { + rrs = append(rrs, sigs...) + } + } + elem, _ = z.Tree.Search(cname[0].(*dns.CNAME).Target) + if elem == nil { + return rrs, nil, nil, Success + } + + i++ + if i > maxChain { + return rrs, nil, nil, Success + } + + goto Redo + } + targets := cnameForType(elem.All(), qtype) - if do { - sigs := elem.Types(dns.TypeRRSIG) - sigs = signatureForSubType(sigs, qtype) - if len(sigs) > 0 { - targets = append(targets, sigs...) + if len(targets) > 0 { + rrs = append(rrs, targets...) + + if do { + sigs := elem.Types(dns.TypeRRSIG) + sigs = signatureForSubType(sigs, qtype) + if len(sigs) > 0 { + rrs = append(rrs, sigs...) + } } } - return append(rrs, targets...), nil, nil, Success + + return rrs, nil, nil, Success } func cnameForType(targets []dns.RR, origQtype uint16) []dns.RR { @@ -317,3 +357,5 @@ func (z *Zone) searchGlue(name string) []dns.RR { } return glue } + +const maxChain = 8