From 94930d20ea241fddb5fa1668a08112dc1acc900d Mon Sep 17 00:00:00 2001 From: Miek Gieben Date: Fri, 30 Aug 2019 13:47:27 +0100 Subject: [PATCH] plugin/file: rework outgoing axfr (#3227) * plugin/file: rework outgoing axfr Signed-off-by: Miek Gieben * Fix test Signed-off-by: Miek Gieben * Actually properly test xfr Signed-off-by: Miek Gieben * Fix test Signed-off-by: Miek Gieben --- plugin/auto/watcher_test.go | 10 ++++----- plugin/file/reload_test.go | 16 ++++++++++---- plugin/file/xfr.go | 43 +++++++++++++++++++++++-------------- plugin/file/zone.go | 37 ++++++++++++------------------- test/auto_test.go | 18 ++++++++++------ 5 files changed, 70 insertions(+), 54 deletions(-) diff --git a/plugin/auto/watcher_test.go b/plugin/auto/watcher_test.go index a7013448b..be2a0d48b 100644 --- a/plugin/auto/watcher_test.go +++ b/plugin/auto/watcher_test.go @@ -30,12 +30,12 @@ func TestWatcher(t *testing.T) { a.Walk() - // example.org and example.com should exist - if x := len(a.Zones.Z["example.org."].All()); x != 4 { - t.Fatalf("Expected 4 RRs, got %d", x) + // example.org and example.com should exist, we have 3 apex rrs and 1 "real" record. All() returns the non-apex ones. + if x := len(a.Zones.Z["example.org."].All()); x != 1 { + t.Fatalf("Expected 1 RRs, got %d", x) } - if x := len(a.Zones.Z["example.com."].All()); x != 4 { - t.Fatalf("Expected 4 RRs, got %d", x) + if x := len(a.Zones.Z["example.com."].All()); x != 1 { + t.Fatalf("Expected 1 RRs, got %d", x) } // Now remove one file, rescan and see if it's gone. diff --git a/plugin/file/reload_test.go b/plugin/file/reload_test.go index 196565cac..f9e544372 100644 --- a/plugin/file/reload_test.go +++ b/plugin/file/reload_test.go @@ -48,8 +48,12 @@ func TestZoneReload(t *testing.T) { t.Fatalf("Failed to lookup, got %d", res) } - if len(z.All()) != 5 { - t.Fatalf("Expected 5 RRs, got %d", len(z.All())) + rrs, err := z.ApexIfDefined() // all apex records. + if err != nil { + t.Fatal(err) + } + if len(rrs) != 5 { + t.Fatalf("Expected 5 RRs, got %d", len(rrs)) } if err := ioutil.WriteFile(fileName, []byte(reloadZone2Test), 0644); err != nil { t.Fatalf("Failed to write new zone data: %s", err) @@ -57,8 +61,12 @@ func TestZoneReload(t *testing.T) { // Could still be racy, but we need to wait a bit for the event to be seen time.Sleep(1 * time.Second) - if len(z.All()) != 3 { - t.Fatalf("Expected 3 RRs, got %d", len(z.All())) + rrs, err = z.ApexIfDefined() + if err != nil { + t.Fatal(err) + } + if len(rrs) != 3 { + t.Fatalf("Expected 3 RRs, got %d", len(rrs)) } } diff --git a/plugin/file/xfr.go b/plugin/file/xfr.go index f5f803d11..659c5c9c3 100644 --- a/plugin/file/xfr.go +++ b/plugin/file/xfr.go @@ -6,6 +6,7 @@ import ( "sync" "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/file/tree" "github.com/coredns/coredns/request" "github.com/miekg/dns" @@ -35,8 +36,9 @@ func (x Xfr) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (in } } - records := x.All() - if len(records) == 0 { + // get soa and apex + apex, err := x.ApexIfDefined() + if err != nil { return dns.RcodeServerFailure, nil } @@ -49,23 +51,34 @@ func (x Xfr) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (in wg.Done() }() - j, l := 0, 0 - records = append(records, records[0]) // add closing SOA to the end - log.Infof("Outgoing transfer of %d records of zone %s to %s started with %d SOA serial", len(records), x.origin, state.IP(), x.SOASerialIfDefined()) - for i, r := range records { - l += dns.Len(r) - if l > transferLength { - ch <- &dns.Envelope{RR: records[j:i]} - l = 0 - j = i + rrs := []dns.RR{} + l := len(apex) + + ch <- &dns.Envelope{RR: apex} + + x.Walk(func(e *tree.Elem, _ map[uint16][]dns.RR) error { + rrs = append(rrs, e.All()...) + if len(rrs) > 500 { + ch <- &dns.Envelope{RR: rrs} + l += len(rrs) + rrs = []dns.RR{} } + return nil + }) + + if len(rrs) > 0 { + ch <- &dns.Envelope{RR: rrs} + l += len(rrs) + rrs = []dns.RR{} } - if j < len(records) { - ch <- &dns.Envelope{RR: records[j:]} - } + + ch <- &dns.Envelope{RR: []dns.RR{apex[0]}} // closing SOA. + l++ + close(ch) // Even though we close the channel here, we still have wg.Wait() // to wait before we can return and close the connection. + log.Infof("Outgoing transfer of %d records of zone %s to %s done with %d SOA serial", l, x.origin, state.IP(), apex[0].(*dns.SOA).Serial) return dns.RcodeSuccess, nil } @@ -103,5 +116,3 @@ func (x Xfr) ServeIxfr(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (i } return dns.RcodeServerFailure, nil } - -const transferLength = 1000 // Start a new envelop after message reaches this size in bytes. Intentionally small to test multi envelope parsing. diff --git a/plugin/file/zone.go b/plugin/file/zone.go index 27c951774..62720abb4 100644 --- a/plugin/file/zone.go +++ b/plugin/file/zone.go @@ -154,36 +154,27 @@ func (z *Zone) TransferAllowed(state request.Request) bool { return false } -// All returns all records from the zone, the first record will be the SOA record, -// optionally followed by all RRSIG(SOA)s. -func (z *Zone) All() []dns.RR { - records := []dns.RR{} +// ApexIfDefined returns the apex nodes from z. The SOA record is the first record, if it does not exist, an error is returned. +func (z *Zone) ApexIfDefined() ([]dns.RR, error) { z.RLock() - allNodes := z.Tree.All() - z.RUnlock() - - for _, a := range allNodes { - records = append(records, a.All()...) + defer z.RUnlock() + if z.Apex.SOA == nil { + return nil, fmt.Errorf("no SOA") } - z.RLock() - if len(z.Apex.SIGNS) > 0 { - records = append(z.Apex.SIGNS, records...) - } - if len(z.Apex.NS) > 0 { - records = append(z.Apex.NS, records...) - } + rrs := []dns.RR{z.Apex.SOA} if len(z.Apex.SIGSOA) > 0 { - records = append(z.Apex.SIGSOA, records...) + rrs = append(rrs, z.Apex.SIGSOA...) + } + if len(z.Apex.NS) > 0 { + rrs = append(rrs, z.Apex.NS...) + } + if len(z.Apex.SIGNS) > 0 { + rrs = append(rrs, z.Apex.SIGNS...) } - if z.Apex.SOA != nil { - z.RUnlock() - return append([]dns.RR{z.Apex.SOA}, records...) - } - z.RUnlock() - return records + return rrs, nil } // NameFromRight returns the labels from the right, staring with the diff --git a/test/auto_test.go b/test/auto_test.go index 4d9b70a1c..07e2af12d 100644 --- a/test/auto_test.go +++ b/test/auto_test.go @@ -129,9 +129,9 @@ func TestAutoAXFR(t *testing.T) { t.Fatalf("Could not get CoreDNS serving instance: %s", err) } - udp, _ := CoreDNSServerPorts(i, 0) - if udp == "" { - t.Fatal("Could not get UDP listening port") + _, tcp := CoreDNSServerPorts(i, 0) + if tcp == "" { + t.Fatal("Could not get TCP listening port") } defer i.Stop() @@ -142,14 +142,20 @@ func TestAutoAXFR(t *testing.T) { time.Sleep(1100 * time.Millisecond) // wait for it to be picked up + tr := new(dns.Transfer) m := new(dns.Msg) m.SetAxfr("example.org.") - resp, err := dns.Exchange(m, udp) + c, err := tr.In(m, tcp) if err != nil { t.Fatal("Expected to receive reply, but didn't") } - if len(resp.Answer) != 5 { - t.Fatalf("Expected response with %d RRs, got %d", 5, len(resp.Answer)) + l := 0 + for e := range c { + l += len(e.RR) + } + + if l != 5 { + t.Fatalf("Expected response with %d RRs, got %d", 5, l) } }