plugin/file: rework outgoing axfr (#3227)

* plugin/file: rework outgoing axfr

Signed-off-by: Miek Gieben <miek@miek.nl>

* Fix test

Signed-off-by: Miek Gieben <miek@miek.nl>

* Actually properly test xfr

Signed-off-by: Miek Gieben <miek@miek.nl>

* Fix test

Signed-off-by: Miek Gieben <miek@miek.nl>
This commit is contained in:
Miek Gieben 2019-08-30 13:47:27 +01:00 committed by GitHub
parent b8a0b52a5e
commit 94930d20ea
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 70 additions and 54 deletions

View file

@ -30,12 +30,12 @@ func TestWatcher(t *testing.T) {
a.Walk()
// example.org and example.com should exist
if x := len(a.Zones.Z["example.org."].All()); x != 4 {
t.Fatalf("Expected 4 RRs, got %d", x)
// example.org and example.com should exist, we have 3 apex rrs and 1 "real" record. All() returns the non-apex ones.
if x := len(a.Zones.Z["example.org."].All()); x != 1 {
t.Fatalf("Expected 1 RRs, got %d", x)
}
if x := len(a.Zones.Z["example.com."].All()); x != 4 {
t.Fatalf("Expected 4 RRs, got %d", x)
if x := len(a.Zones.Z["example.com."].All()); x != 1 {
t.Fatalf("Expected 1 RRs, got %d", x)
}
// Now remove one file, rescan and see if it's gone.

View file

@ -48,8 +48,12 @@ func TestZoneReload(t *testing.T) {
t.Fatalf("Failed to lookup, got %d", res)
}
if len(z.All()) != 5 {
t.Fatalf("Expected 5 RRs, got %d", len(z.All()))
rrs, err := z.ApexIfDefined() // all apex records.
if err != nil {
t.Fatal(err)
}
if len(rrs) != 5 {
t.Fatalf("Expected 5 RRs, got %d", len(rrs))
}
if err := ioutil.WriteFile(fileName, []byte(reloadZone2Test), 0644); err != nil {
t.Fatalf("Failed to write new zone data: %s", err)
@ -57,8 +61,12 @@ func TestZoneReload(t *testing.T) {
// Could still be racy, but we need to wait a bit for the event to be seen
time.Sleep(1 * time.Second)
if len(z.All()) != 3 {
t.Fatalf("Expected 3 RRs, got %d", len(z.All()))
rrs, err = z.ApexIfDefined()
if err != nil {
t.Fatal(err)
}
if len(rrs) != 3 {
t.Fatalf("Expected 3 RRs, got %d", len(rrs))
}
}

View file

@ -6,6 +6,7 @@ import (
"sync"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/file/tree"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
@ -35,8 +36,9 @@ func (x Xfr) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (in
}
}
records := x.All()
if len(records) == 0 {
// get soa and apex
apex, err := x.ApexIfDefined()
if err != nil {
return dns.RcodeServerFailure, nil
}
@ -49,23 +51,34 @@ func (x Xfr) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (in
wg.Done()
}()
j, l := 0, 0
records = append(records, records[0]) // add closing SOA to the end
log.Infof("Outgoing transfer of %d records of zone %s to %s started with %d SOA serial", len(records), x.origin, state.IP(), x.SOASerialIfDefined())
for i, r := range records {
l += dns.Len(r)
if l > transferLength {
ch <- &dns.Envelope{RR: records[j:i]}
l = 0
j = i
rrs := []dns.RR{}
l := len(apex)
ch <- &dns.Envelope{RR: apex}
x.Walk(func(e *tree.Elem, _ map[uint16][]dns.RR) error {
rrs = append(rrs, e.All()...)
if len(rrs) > 500 {
ch <- &dns.Envelope{RR: rrs}
l += len(rrs)
rrs = []dns.RR{}
}
return nil
})
if len(rrs) > 0 {
ch <- &dns.Envelope{RR: rrs}
l += len(rrs)
rrs = []dns.RR{}
}
if j < len(records) {
ch <- &dns.Envelope{RR: records[j:]}
}
ch <- &dns.Envelope{RR: []dns.RR{apex[0]}} // closing SOA.
l++
close(ch) // Even though we close the channel here, we still have
wg.Wait() // to wait before we can return and close the connection.
log.Infof("Outgoing transfer of %d records of zone %s to %s done with %d SOA serial", l, x.origin, state.IP(), apex[0].(*dns.SOA).Serial)
return dns.RcodeSuccess, nil
}
@ -103,5 +116,3 @@ func (x Xfr) ServeIxfr(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (i
}
return dns.RcodeServerFailure, nil
}
const transferLength = 1000 // Start a new envelop after message reaches this size in bytes. Intentionally small to test multi envelope parsing.

View file

@ -154,36 +154,27 @@ func (z *Zone) TransferAllowed(state request.Request) bool {
return false
}
// All returns all records from the zone, the first record will be the SOA record,
// optionally followed by all RRSIG(SOA)s.
func (z *Zone) All() []dns.RR {
records := []dns.RR{}
// ApexIfDefined returns the apex nodes from z. The SOA record is the first record, if it does not exist, an error is returned.
func (z *Zone) ApexIfDefined() ([]dns.RR, error) {
z.RLock()
allNodes := z.Tree.All()
z.RUnlock()
for _, a := range allNodes {
records = append(records, a.All()...)
defer z.RUnlock()
if z.Apex.SOA == nil {
return nil, fmt.Errorf("no SOA")
}
z.RLock()
if len(z.Apex.SIGNS) > 0 {
records = append(z.Apex.SIGNS, records...)
}
if len(z.Apex.NS) > 0 {
records = append(z.Apex.NS, records...)
}
rrs := []dns.RR{z.Apex.SOA}
if len(z.Apex.SIGSOA) > 0 {
records = append(z.Apex.SIGSOA, records...)
rrs = append(rrs, z.Apex.SIGSOA...)
}
if len(z.Apex.NS) > 0 {
rrs = append(rrs, z.Apex.NS...)
}
if len(z.Apex.SIGNS) > 0 {
rrs = append(rrs, z.Apex.SIGNS...)
}
if z.Apex.SOA != nil {
z.RUnlock()
return append([]dns.RR{z.Apex.SOA}, records...)
}
z.RUnlock()
return records
return rrs, nil
}
// NameFromRight returns the labels from the right, staring with the

View file

@ -129,9 +129,9 @@ func TestAutoAXFR(t *testing.T) {
t.Fatalf("Could not get CoreDNS serving instance: %s", err)
}
udp, _ := CoreDNSServerPorts(i, 0)
if udp == "" {
t.Fatal("Could not get UDP listening port")
_, tcp := CoreDNSServerPorts(i, 0)
if tcp == "" {
t.Fatal("Could not get TCP listening port")
}
defer i.Stop()
@ -142,14 +142,20 @@ func TestAutoAXFR(t *testing.T) {
time.Sleep(1100 * time.Millisecond) // wait for it to be picked up
tr := new(dns.Transfer)
m := new(dns.Msg)
m.SetAxfr("example.org.")
resp, err := dns.Exchange(m, udp)
c, err := tr.In(m, tcp)
if err != nil {
t.Fatal("Expected to receive reply, but didn't")
}
if len(resp.Answer) != 5 {
t.Fatalf("Expected response with %d RRs, got %d", 5, len(resp.Answer))
l := 0
for e := range c {
l += len(e.RR)
}
if l != 5 {
t.Fatalf("Expected response with %d RRs, got %d", 5, l)
}
}