plugin/file: rework outgoing axfr (#3227)
* plugin/file: rework outgoing axfr Signed-off-by: Miek Gieben <miek@miek.nl> * Fix test Signed-off-by: Miek Gieben <miek@miek.nl> * Actually properly test xfr Signed-off-by: Miek Gieben <miek@miek.nl> * Fix test Signed-off-by: Miek Gieben <miek@miek.nl>
This commit is contained in:
parent
b8a0b52a5e
commit
94930d20ea
5 changed files with 70 additions and 54 deletions
|
@ -30,12 +30,12 @@ func TestWatcher(t *testing.T) {
|
|||
|
||||
a.Walk()
|
||||
|
||||
// example.org and example.com should exist
|
||||
if x := len(a.Zones.Z["example.org."].All()); x != 4 {
|
||||
t.Fatalf("Expected 4 RRs, got %d", x)
|
||||
// example.org and example.com should exist, we have 3 apex rrs and 1 "real" record. All() returns the non-apex ones.
|
||||
if x := len(a.Zones.Z["example.org."].All()); x != 1 {
|
||||
t.Fatalf("Expected 1 RRs, got %d", x)
|
||||
}
|
||||
if x := len(a.Zones.Z["example.com."].All()); x != 4 {
|
||||
t.Fatalf("Expected 4 RRs, got %d", x)
|
||||
if x := len(a.Zones.Z["example.com."].All()); x != 1 {
|
||||
t.Fatalf("Expected 1 RRs, got %d", x)
|
||||
}
|
||||
|
||||
// Now remove one file, rescan and see if it's gone.
|
||||
|
|
|
@ -48,8 +48,12 @@ func TestZoneReload(t *testing.T) {
|
|||
t.Fatalf("Failed to lookup, got %d", res)
|
||||
}
|
||||
|
||||
if len(z.All()) != 5 {
|
||||
t.Fatalf("Expected 5 RRs, got %d", len(z.All()))
|
||||
rrs, err := z.ApexIfDefined() // all apex records.
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(rrs) != 5 {
|
||||
t.Fatalf("Expected 5 RRs, got %d", len(rrs))
|
||||
}
|
||||
if err := ioutil.WriteFile(fileName, []byte(reloadZone2Test), 0644); err != nil {
|
||||
t.Fatalf("Failed to write new zone data: %s", err)
|
||||
|
@ -57,8 +61,12 @@ func TestZoneReload(t *testing.T) {
|
|||
// Could still be racy, but we need to wait a bit for the event to be seen
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
if len(z.All()) != 3 {
|
||||
t.Fatalf("Expected 3 RRs, got %d", len(z.All()))
|
||||
rrs, err = z.ApexIfDefined()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(rrs) != 3 {
|
||||
t.Fatalf("Expected 3 RRs, got %d", len(rrs))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"sync"
|
||||
|
||||
"github.com/coredns/coredns/plugin"
|
||||
"github.com/coredns/coredns/plugin/file/tree"
|
||||
"github.com/coredns/coredns/request"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
@ -35,8 +36,9 @@ func (x Xfr) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (in
|
|||
}
|
||||
}
|
||||
|
||||
records := x.All()
|
||||
if len(records) == 0 {
|
||||
// get soa and apex
|
||||
apex, err := x.ApexIfDefined()
|
||||
if err != nil {
|
||||
return dns.RcodeServerFailure, nil
|
||||
}
|
||||
|
||||
|
@ -49,23 +51,34 @@ func (x Xfr) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (in
|
|||
wg.Done()
|
||||
}()
|
||||
|
||||
j, l := 0, 0
|
||||
records = append(records, records[0]) // add closing SOA to the end
|
||||
log.Infof("Outgoing transfer of %d records of zone %s to %s started with %d SOA serial", len(records), x.origin, state.IP(), x.SOASerialIfDefined())
|
||||
for i, r := range records {
|
||||
l += dns.Len(r)
|
||||
if l > transferLength {
|
||||
ch <- &dns.Envelope{RR: records[j:i]}
|
||||
l = 0
|
||||
j = i
|
||||
rrs := []dns.RR{}
|
||||
l := len(apex)
|
||||
|
||||
ch <- &dns.Envelope{RR: apex}
|
||||
|
||||
x.Walk(func(e *tree.Elem, _ map[uint16][]dns.RR) error {
|
||||
rrs = append(rrs, e.All()...)
|
||||
if len(rrs) > 500 {
|
||||
ch <- &dns.Envelope{RR: rrs}
|
||||
l += len(rrs)
|
||||
rrs = []dns.RR{}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if len(rrs) > 0 {
|
||||
ch <- &dns.Envelope{RR: rrs}
|
||||
l += len(rrs)
|
||||
rrs = []dns.RR{}
|
||||
}
|
||||
if j < len(records) {
|
||||
ch <- &dns.Envelope{RR: records[j:]}
|
||||
}
|
||||
|
||||
ch <- &dns.Envelope{RR: []dns.RR{apex[0]}} // closing SOA.
|
||||
l++
|
||||
|
||||
close(ch) // Even though we close the channel here, we still have
|
||||
wg.Wait() // to wait before we can return and close the connection.
|
||||
|
||||
log.Infof("Outgoing transfer of %d records of zone %s to %s done with %d SOA serial", l, x.origin, state.IP(), apex[0].(*dns.SOA).Serial)
|
||||
return dns.RcodeSuccess, nil
|
||||
}
|
||||
|
||||
|
@ -103,5 +116,3 @@ func (x Xfr) ServeIxfr(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (i
|
|||
}
|
||||
return dns.RcodeServerFailure, nil
|
||||
}
|
||||
|
||||
const transferLength = 1000 // Start a new envelop after message reaches this size in bytes. Intentionally small to test multi envelope parsing.
|
||||
|
|
|
@ -154,36 +154,27 @@ func (z *Zone) TransferAllowed(state request.Request) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
// All returns all records from the zone, the first record will be the SOA record,
|
||||
// optionally followed by all RRSIG(SOA)s.
|
||||
func (z *Zone) All() []dns.RR {
|
||||
records := []dns.RR{}
|
||||
// ApexIfDefined returns the apex nodes from z. The SOA record is the first record, if it does not exist, an error is returned.
|
||||
func (z *Zone) ApexIfDefined() ([]dns.RR, error) {
|
||||
z.RLock()
|
||||
allNodes := z.Tree.All()
|
||||
z.RUnlock()
|
||||
|
||||
for _, a := range allNodes {
|
||||
records = append(records, a.All()...)
|
||||
defer z.RUnlock()
|
||||
if z.Apex.SOA == nil {
|
||||
return nil, fmt.Errorf("no SOA")
|
||||
}
|
||||
|
||||
z.RLock()
|
||||
if len(z.Apex.SIGNS) > 0 {
|
||||
records = append(z.Apex.SIGNS, records...)
|
||||
}
|
||||
if len(z.Apex.NS) > 0 {
|
||||
records = append(z.Apex.NS, records...)
|
||||
}
|
||||
rrs := []dns.RR{z.Apex.SOA}
|
||||
|
||||
if len(z.Apex.SIGSOA) > 0 {
|
||||
records = append(z.Apex.SIGSOA, records...)
|
||||
rrs = append(rrs, z.Apex.SIGSOA...)
|
||||
}
|
||||
if len(z.Apex.NS) > 0 {
|
||||
rrs = append(rrs, z.Apex.NS...)
|
||||
}
|
||||
if len(z.Apex.SIGNS) > 0 {
|
||||
rrs = append(rrs, z.Apex.SIGNS...)
|
||||
}
|
||||
|
||||
if z.Apex.SOA != nil {
|
||||
z.RUnlock()
|
||||
return append([]dns.RR{z.Apex.SOA}, records...)
|
||||
}
|
||||
z.RUnlock()
|
||||
return records
|
||||
return rrs, nil
|
||||
}
|
||||
|
||||
// NameFromRight returns the labels from the right, staring with the
|
||||
|
|
|
@ -129,9 +129,9 @@ func TestAutoAXFR(t *testing.T) {
|
|||
t.Fatalf("Could not get CoreDNS serving instance: %s", err)
|
||||
}
|
||||
|
||||
udp, _ := CoreDNSServerPorts(i, 0)
|
||||
if udp == "" {
|
||||
t.Fatal("Could not get UDP listening port")
|
||||
_, tcp := CoreDNSServerPorts(i, 0)
|
||||
if tcp == "" {
|
||||
t.Fatal("Could not get TCP listening port")
|
||||
}
|
||||
defer i.Stop()
|
||||
|
||||
|
@ -142,14 +142,20 @@ func TestAutoAXFR(t *testing.T) {
|
|||
|
||||
time.Sleep(1100 * time.Millisecond) // wait for it to be picked up
|
||||
|
||||
tr := new(dns.Transfer)
|
||||
m := new(dns.Msg)
|
||||
m.SetAxfr("example.org.")
|
||||
resp, err := dns.Exchange(m, udp)
|
||||
c, err := tr.In(m, tcp)
|
||||
if err != nil {
|
||||
t.Fatal("Expected to receive reply, but didn't")
|
||||
}
|
||||
if len(resp.Answer) != 5 {
|
||||
t.Fatalf("Expected response with %d RRs, got %d", 5, len(resp.Answer))
|
||||
l := 0
|
||||
for e := range c {
|
||||
l += len(e.RR)
|
||||
}
|
||||
|
||||
if l != 5 {
|
||||
t.Fatalf("Expected response with %d RRs, got %d", 5, l)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue