plugin/file: rework outgoing axfr (#3227)

* plugin/file: rework outgoing axfr

Signed-off-by: Miek Gieben <miek@miek.nl>

* Fix test

Signed-off-by: Miek Gieben <miek@miek.nl>

* Actually properly test xfr

Signed-off-by: Miek Gieben <miek@miek.nl>

* Fix test

Signed-off-by: Miek Gieben <miek@miek.nl>
This commit is contained in:
Miek Gieben 2019-08-30 13:47:27 +01:00 committed by GitHub
parent b8a0b52a5e
commit 94930d20ea
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 70 additions and 54 deletions

View file

@ -30,12 +30,12 @@ func TestWatcher(t *testing.T) {
a.Walk() a.Walk()
// example.org and example.com should exist // example.org and example.com should exist, we have 3 apex rrs and 1 "real" record. All() returns the non-apex ones.
if x := len(a.Zones.Z["example.org."].All()); x != 4 { if x := len(a.Zones.Z["example.org."].All()); x != 1 {
t.Fatalf("Expected 4 RRs, got %d", x) t.Fatalf("Expected 1 RRs, got %d", x)
} }
if x := len(a.Zones.Z["example.com."].All()); x != 4 { if x := len(a.Zones.Z["example.com."].All()); x != 1 {
t.Fatalf("Expected 4 RRs, got %d", x) t.Fatalf("Expected 1 RRs, got %d", x)
} }
// Now remove one file, rescan and see if it's gone. // Now remove one file, rescan and see if it's gone.

View file

@ -48,8 +48,12 @@ func TestZoneReload(t *testing.T) {
t.Fatalf("Failed to lookup, got %d", res) t.Fatalf("Failed to lookup, got %d", res)
} }
if len(z.All()) != 5 { rrs, err := z.ApexIfDefined() // all apex records.
t.Fatalf("Expected 5 RRs, got %d", len(z.All())) if err != nil {
t.Fatal(err)
}
if len(rrs) != 5 {
t.Fatalf("Expected 5 RRs, got %d", len(rrs))
} }
if err := ioutil.WriteFile(fileName, []byte(reloadZone2Test), 0644); err != nil { if err := ioutil.WriteFile(fileName, []byte(reloadZone2Test), 0644); err != nil {
t.Fatalf("Failed to write new zone data: %s", err) t.Fatalf("Failed to write new zone data: %s", err)
@ -57,8 +61,12 @@ func TestZoneReload(t *testing.T) {
// Could still be racy, but we need to wait a bit for the event to be seen // Could still be racy, but we need to wait a bit for the event to be seen
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
if len(z.All()) != 3 { rrs, err = z.ApexIfDefined()
t.Fatalf("Expected 3 RRs, got %d", len(z.All())) if err != nil {
t.Fatal(err)
}
if len(rrs) != 3 {
t.Fatalf("Expected 3 RRs, got %d", len(rrs))
} }
} }

View file

@ -6,6 +6,7 @@ import (
"sync" "sync"
"github.com/coredns/coredns/plugin" "github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/file/tree"
"github.com/coredns/coredns/request" "github.com/coredns/coredns/request"
"github.com/miekg/dns" "github.com/miekg/dns"
@ -35,8 +36,9 @@ func (x Xfr) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (in
} }
} }
records := x.All() // get soa and apex
if len(records) == 0 { apex, err := x.ApexIfDefined()
if err != nil {
return dns.RcodeServerFailure, nil return dns.RcodeServerFailure, nil
} }
@ -49,23 +51,34 @@ func (x Xfr) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (in
wg.Done() wg.Done()
}() }()
j, l := 0, 0 rrs := []dns.RR{}
records = append(records, records[0]) // add closing SOA to the end l := len(apex)
log.Infof("Outgoing transfer of %d records of zone %s to %s started with %d SOA serial", len(records), x.origin, state.IP(), x.SOASerialIfDefined())
for i, r := range records { ch <- &dns.Envelope{RR: apex}
l += dns.Len(r)
if l > transferLength { x.Walk(func(e *tree.Elem, _ map[uint16][]dns.RR) error {
ch <- &dns.Envelope{RR: records[j:i]} rrs = append(rrs, e.All()...)
l = 0 if len(rrs) > 500 {
j = i ch <- &dns.Envelope{RR: rrs}
l += len(rrs)
rrs = []dns.RR{}
} }
return nil
})
if len(rrs) > 0 {
ch <- &dns.Envelope{RR: rrs}
l += len(rrs)
rrs = []dns.RR{}
} }
if j < len(records) {
ch <- &dns.Envelope{RR: records[j:]} ch <- &dns.Envelope{RR: []dns.RR{apex[0]}} // closing SOA.
} l++
close(ch) // Even though we close the channel here, we still have close(ch) // Even though we close the channel here, we still have
wg.Wait() // to wait before we can return and close the connection. wg.Wait() // to wait before we can return and close the connection.
log.Infof("Outgoing transfer of %d records of zone %s to %s done with %d SOA serial", l, x.origin, state.IP(), apex[0].(*dns.SOA).Serial)
return dns.RcodeSuccess, nil return dns.RcodeSuccess, nil
} }
@ -103,5 +116,3 @@ func (x Xfr) ServeIxfr(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (i
} }
return dns.RcodeServerFailure, nil return dns.RcodeServerFailure, nil
} }
const transferLength = 1000 // Start a new envelop after message reaches this size in bytes. Intentionally small to test multi envelope parsing.

View file

@ -154,36 +154,27 @@ func (z *Zone) TransferAllowed(state request.Request) bool {
return false return false
} }
// All returns all records from the zone, the first record will be the SOA record, // ApexIfDefined returns the apex nodes from z. The SOA record is the first record, if it does not exist, an error is returned.
// optionally followed by all RRSIG(SOA)s. func (z *Zone) ApexIfDefined() ([]dns.RR, error) {
func (z *Zone) All() []dns.RR {
records := []dns.RR{}
z.RLock() z.RLock()
allNodes := z.Tree.All() defer z.RUnlock()
z.RUnlock() if z.Apex.SOA == nil {
return nil, fmt.Errorf("no SOA")
for _, a := range allNodes {
records = append(records, a.All()...)
} }
z.RLock() rrs := []dns.RR{z.Apex.SOA}
if len(z.Apex.SIGNS) > 0 {
records = append(z.Apex.SIGNS, records...)
}
if len(z.Apex.NS) > 0 {
records = append(z.Apex.NS, records...)
}
if len(z.Apex.SIGSOA) > 0 { if len(z.Apex.SIGSOA) > 0 {
records = append(z.Apex.SIGSOA, records...) rrs = append(rrs, z.Apex.SIGSOA...)
}
if len(z.Apex.NS) > 0 {
rrs = append(rrs, z.Apex.NS...)
}
if len(z.Apex.SIGNS) > 0 {
rrs = append(rrs, z.Apex.SIGNS...)
} }
if z.Apex.SOA != nil { return rrs, nil
z.RUnlock()
return append([]dns.RR{z.Apex.SOA}, records...)
}
z.RUnlock()
return records
} }
// NameFromRight returns the labels from the right, staring with the // NameFromRight returns the labels from the right, staring with the

View file

@ -129,9 +129,9 @@ func TestAutoAXFR(t *testing.T) {
t.Fatalf("Could not get CoreDNS serving instance: %s", err) t.Fatalf("Could not get CoreDNS serving instance: %s", err)
} }
udp, _ := CoreDNSServerPorts(i, 0) _, tcp := CoreDNSServerPorts(i, 0)
if udp == "" { if tcp == "" {
t.Fatal("Could not get UDP listening port") t.Fatal("Could not get TCP listening port")
} }
defer i.Stop() defer i.Stop()
@ -142,14 +142,20 @@ func TestAutoAXFR(t *testing.T) {
time.Sleep(1100 * time.Millisecond) // wait for it to be picked up time.Sleep(1100 * time.Millisecond) // wait for it to be picked up
tr := new(dns.Transfer)
m := new(dns.Msg) m := new(dns.Msg)
m.SetAxfr("example.org.") m.SetAxfr("example.org.")
resp, err := dns.Exchange(m, udp) c, err := tr.In(m, tcp)
if err != nil { if err != nil {
t.Fatal("Expected to receive reply, but didn't") t.Fatal("Expected to receive reply, but didn't")
} }
if len(resp.Answer) != 5 { l := 0
t.Fatalf("Expected response with %d RRs, got %d", 5, len(resp.Answer)) for e := range c {
l += len(e.RR)
}
if l != 5 {
t.Fatalf("Expected response with %d RRs, got %d", 5, l)
} }
} }