From c9d8a57ed66f8debe72856ec9d12c25f97248f0e Mon Sep 17 00:00:00 2001 From: Miek Gieben Date: Fri, 15 Apr 2016 14:26:27 +0100 Subject: [PATCH] Reload zone when a write is detected (#122) Zone reloading --- core/setup/file.go | 26 +++++---- core/setup/secondary.go | 2 +- middleware/file/README.md | 4 +- middleware/file/file.go | 17 +----- middleware/file/notify.go | 2 +- middleware/file/reload_test.go | 62 +++++++++++++++++++++ middleware/file/secondary.go | 29 +++------- middleware/file/secondary_test.go | 6 +-- middleware/file/xfr.go | 2 +- middleware/file/zone.go | 89 ++++++++++++++++++++++++++++--- 10 files changed, 179 insertions(+), 60 deletions(-) create mode 100644 middleware/file/reload_test.go diff --git a/core/setup/file.go b/core/setup/file.go index 189977317..1fe04c1ec 100644 --- a/core/setup/file.go +++ b/core/setup/file.go @@ -16,18 +16,17 @@ func File(c *Controller) (middleware.Middleware, error) { return nil, err } - // Add startup functions to notify the master. + // Add startup functions to notify the master(s). for _, n := range zones.Names { - if len(zones.Z[n].TransferTo) > 0 { - c.Startup = append(c.Startup, func() error { - zones.Z[n].StartupOnce.Do(func() { - if len(zones.Z[n].TransferTo) > 0 { - zones.Z[n].Notify() - } - }) - return nil + c.Startup = append(c.Startup, func() error { + zones.Z[n].StartupOnce.Do(func() { + if len(zones.Z[n].TransferTo) > 0 { + zones.Z[n].Notify() + } + zones.Z[n].Reload(nil) }) - } + return nil + }) } return func(next middleware.Handler) middleware.Handler { @@ -67,17 +66,24 @@ func fileParse(c *Controller) (file.Zones, error) { names = append(names, origins[i]) } + noReload := false for c.NextBlock() { t, _, e := parseTransfer(c) if e != nil { return file.Zones{}, e } + switch c.Val() { + case "no_reload": + noReload = true + } // discard from, here, maybe check and show log when we do? for _, origin := range origins { if t != nil { z[origin].TransferTo = append(z[origin].TransferTo, t...) } + z[origin].NoReload = noReload } + } } } diff --git a/core/setup/secondary.go b/core/setup/secondary.go index 5f1ad3a17..0abf82c04 100644 --- a/core/setup/secondary.go +++ b/core/setup/secondary.go @@ -47,7 +47,7 @@ func secondaryParse(c *Controller) (file.Zones, error) { } for i, _ := range origins { origins[i] = middleware.Host(origins[i]).Normalize() - z[origins[i]] = file.NewZone(origins[i]) + z[origins[i]] = file.NewZone(origins[i], "stdin") names = append(names, origins[i]) } diff --git a/middleware/file/README.md b/middleware/file/README.md index f0bf58101..dc053dd7c 100644 --- a/middleware/file/README.md +++ b/middleware/file/README.md @@ -24,13 +24,15 @@ TSIG key information, something like `transfer out [address...] key [name] [base file dbfile [zones... ] { transfer from [address...] transfer to [address...] - + no_reload } ~~~ * `transfer` enables zone transfers. It may be specified multiples times. *To* or *from* signals the direction. Addresses must be denoted in CIDR notation (127.0.0.1/32 etc.) or just as plain address. The special wildcard "*" means: the entire internet (only valid for 'transfer to'). +* `no_reload` by default CoreDNS will reload a zone from disk whenever it detects a change to the + file. This option disables that behavior. ## Examples diff --git a/middleware/file/file.go b/middleware/file/file.go index 50ae3fd26..e44fa4cdb 100644 --- a/middleware/file/file.go +++ b/middleware/file/file.go @@ -105,27 +105,14 @@ func (f File) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (i // Parse parses the zone in filename and returns a new Zone or an error. func Parse(f io.Reader, origin, fileName string) (*Zone, error) { tokens := dns.ParseZone(f, dns.Fqdn(origin), fileName) - z := NewZone(origin) + z := NewZone(origin, fileName) for x := range tokens { if x.Error != nil { log.Printf("[ERROR] Failed to parse `%s': %v", origin, x.Error) return nil, x.Error } - switch h := x.RR.Header().Rrtype; h { - case dns.TypeSOA: - z.SOA = x.RR.(*dns.SOA) - case dns.TypeNSEC3, dns.TypeNSEC3PARAM: - err := fmt.Errorf("NSEC3 zone is not supported, dropping") - log.Printf("[ERROR] Failed to parse `%s': %v", origin, err) + if err := z.Insert(x.RR); err != nil { return nil, err - case dns.TypeRRSIG: - if x, ok := x.RR.(*dns.RRSIG); ok && x.TypeCovered == dns.TypeSOA { - z.SIG = append(z.SIG, x) - continue - } - fallthrough - default: - z.Insert(x.RR) } } return z, nil diff --git a/middleware/file/notify.go b/middleware/file/notify.go index b369f6ad1..8a2581b84 100644 --- a/middleware/file/notify.go +++ b/middleware/file/notify.go @@ -30,7 +30,7 @@ func (z *Zone) isNotify(state middleware.State) bool { // Notify will send notifies to all configured TransferTo IP addresses. func (z *Zone) Notify() { - go notify(z.name, z.TransferTo) + go notify(z.origin, z.TransferTo) } // notify sends notifies to the configured remote servers. It will try up to three times diff --git a/middleware/file/reload_test.go b/middleware/file/reload_test.go new file mode 100644 index 000000000..1769c701e --- /dev/null +++ b/middleware/file/reload_test.go @@ -0,0 +1,62 @@ +package file + +import ( + "io/ioutil" + "os" + "testing" + "time" + + "github.com/miekg/coredns/middleware/test" + "github.com/miekg/dns" +) + +func TestZoneReload(t *testing.T) { + fileName, rm, err := test.Zone(t, ".", reloadZoneTest) + if err != nil { + t.Fatalf("failed to create zone: %s", err) + } + defer rm() + reader, err := os.Open(fileName) + if err != nil { + t.Fatalf("failed to open zone: %s", err) + } + z, err := Parse(reader, "miek.nl", fileName) + if err != nil { + t.Fatalf("failed to parse zone: %s", err) + } + + z.Reload(nil) + + if _, _, _, res := z.Lookup("miek.nl.", dns.TypeSOA, false); res != Success { + t.Fatalf("failed to lookup, got %d", res) + } + + if _, _, _, res := z.Lookup("miek.nl.", dns.TypeNS, false); res != Success { + t.Fatalf("failed to lookup, got %d", res) + } + + if len(z.All()) != 5 { + t.Fatalf("expected 5 RRs, got %d", len(z.All())) + } + if err := ioutil.WriteFile(fileName, []byte(reloadZone2Test), 0644); err != nil { + t.Fatalf("failed to write new zone data", err) + } + // Could still be racy, but we need to wait a bit for the event to be seen + time.Sleep(1 * time.Second) + + if len(z.All()) != 3 { + t.Fatalf("expected 3 RRs, got %d", len(z.All())) + } +} + +const reloadZoneTest = `miek.nl. 1627 IN SOA linode.atoom.net. miek.miek.nl. 1460175181 14400 3600 604800 14400 +miek.nl. 1627 IN NS ext.ns.whyscream.net. +miek.nl. 1627 IN NS omval.tednet.nl. +miek.nl. 1627 IN NS linode.atoom.net. +miek.nl. 1627 IN NS ns-ext.nlnetlabs.nl. +` + +const reloadZone2Test = `miek.nl. 1627 IN SOA linode.atoom.net. miek.miek.nl. 1460175181 14400 3600 604800 14400 +miek.nl. 1627 IN NS ext.ns.whyscream.net. +miek.nl. 1627 IN NS omval.tednet.nl. +` diff --git a/middleware/file/secondary.go b/middleware/file/secondary.go index 9b3886a36..eb95392ad 100644 --- a/middleware/file/secondary.go +++ b/middleware/file/secondary.go @@ -1,7 +1,6 @@ package file import ( - "fmt" "log" "time" @@ -16,7 +15,7 @@ func (z *Zone) TransferIn() error { return nil } m := new(dns.Msg) - m.SetAxfr(z.name) + m.SetAxfr(z.origin) z1 := z.Copy() var ( @@ -29,32 +28,20 @@ Transfer: t := new(dns.Transfer) c, err := t.In(m, tr) if err != nil { - log.Printf("[ERROR] Failed to setup transfer `%s' with `%s': %v", z.name, tr, err) + log.Printf("[ERROR] Failed to setup transfer `%s' with `%s': %v", z.origin, tr, err) Err = err continue Transfer } for env := range c { if env.Error != nil { - log.Printf("[ERROR] Failed to parse transfer `%s': %v", z.name, env.Error) + log.Printf("[ERROR] Failed to parse transfer `%s': %v", z.origin, env.Error) Err = env.Error continue Transfer } for _, rr := range env.RR { - switch h := rr.Header().Rrtype; h { - case dns.TypeSOA: - z1.SOA = rr.(*dns.SOA) - case dns.TypeNSEC3, dns.TypeNSEC3PARAM: - err := fmt.Errorf("NSEC3 zone is not supported, dropping") - log.Printf("[ERROR] Failed to parse transfer `%s': %v", z.name, err) + if err := z1.Insert(rr); err != nil { + log.Printf("[ERROR] Failed to parse transfer `%s': %v", z.origin, err) return err - case dns.TypeRRSIG: - if x, ok := rr.(*dns.RRSIG); ok && x.TypeCovered == dns.TypeSOA { - z1.SIG = append(z1.SIG, x) - continue - } - fallthrough - default: - z1.Insert(rr) } } } @@ -62,7 +49,7 @@ Transfer: break } if Err != nil { - log.Printf("[ERROR] Failed to transfer %s: %s", z.name, Err) + log.Printf("[ERROR] Failed to transfer %s: %s", z.origin, Err) return Err } @@ -70,7 +57,7 @@ Transfer: z.SOA = z1.SOA z.SIG = z1.SIG *z.Expired = false - log.Printf("[INFO] Transferred: %s from %s", z.name, tr) + log.Printf("[INFO] Transferred: %s from %s", z.origin, tr) return nil } @@ -80,7 +67,7 @@ func (z *Zone) shouldTransfer() (bool, error) { c := new(dns.Client) c.Net = "tcp" // do this query over TCP to minimize spoofing m := new(dns.Msg) - m.SetQuestion(z.name, dns.TypeSOA) + m.SetQuestion(z.origin, dns.TypeSOA) var Err error serial := -1 diff --git a/middleware/file/secondary_test.go b/middleware/file/secondary_test.go index e86751a8c..ce0f4004f 100644 --- a/middleware/file/secondary_test.go +++ b/middleware/file/secondary_test.go @@ -81,7 +81,7 @@ func TestShouldTransfer(t *testing.T) { defer s.Shutdown() z := new(Zone) - z.name = testZone + z.origin = testZone z.TransferFrom = []string{addrstr} // Serial smaller @@ -118,7 +118,7 @@ func TestTransferIn(t *testing.T) { z := new(Zone) z.Expired = new(bool) - z.name = testZone + z.origin = testZone z.TransferFrom = []string{addrstr} err = z.TransferIn() @@ -133,7 +133,7 @@ func TestTransferIn(t *testing.T) { func TestIsNotify(t *testing.T) { z := new(Zone) z.Expired = new(bool) - z.name = testZone + z.origin = testZone state := NewState(testZone, dns.TypeSOA) // need to set opcode state.Req.Opcode = dns.OpcodeNotify diff --git a/middleware/file/xfr.go b/middleware/file/xfr.go index 228e11c5f..1d87a244b 100644 --- a/middleware/file/xfr.go +++ b/middleware/file/xfr.go @@ -38,7 +38,7 @@ func (x Xfr) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (in j, l := 0, 0 records = append(records, records[0]) // add closing SOA to the end - log.Printf("[INFO] Outgoing transfer of %d records of zone %s to %s started", len(records), x.name, state.IP()) + log.Printf("[INFO] Outgoing transfer of %d records of zone %s to %s started", len(records), x.origin, state.IP()) for i, r := range records { l += dns.Len(r) if l > transferLength { diff --git a/middleware/file/zone.go b/middleware/file/zone.go index f9bb8efe2..40389f5a4 100644 --- a/middleware/file/zone.go +++ b/middleware/file/zone.go @@ -1,36 +1,45 @@ package file import ( + "fmt" + "log" + "os" "sync" "github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware/file/tree" + "github.com/fsnotify/fsnotify" "github.com/miekg/dns" ) type Zone struct { - SOA *dns.SOA - SIG []dns.RR - name string + SOA *dns.SOA + SIG []dns.RR + origin string + file string *tree.Tree TransferTo []string StartupOnce sync.Once TransferFrom []string Expired *bool + + NoReload bool + reloadMu sync.RWMutex + // TODO: shutdown watcher channel } // NewZone returns a new zone. -func NewZone(name string) *Zone { - z := &Zone{name: dns.Fqdn(name), Tree: &tree.Tree{}, Expired: new(bool)} +func NewZone(name, file string) *Zone { + z := &Zone{origin: dns.Fqdn(name), file: file, Tree: &tree.Tree{}, Expired: new(bool)} *z.Expired = false return z } // Copy copies a zone *without* copying the zone's content. It is not a deep copy. func (z *Zone) Copy() *Zone { - z1 := NewZone(z.name) + z1 := NewZone(z.origin, z.file) z1.TransferTo = z.TransferTo z1.TransferFrom = z.TransferFrom z1.Expired = z.Expired @@ -40,7 +49,24 @@ func (z *Zone) Copy() *Zone { } // Insert inserts r into z. -func (z *Zone) Insert(r dns.RR) { z.Tree.Insert(r) } +func (z *Zone) Insert(r dns.RR) error { + switch h := r.Header().Rrtype; h { + case dns.TypeSOA: + z.SOA = r.(*dns.SOA) + return nil + case dns.TypeNSEC3, dns.TypeNSEC3PARAM: + return fmt.Errorf("NSEC3 zone is not supported, dropping") + case dns.TypeRRSIG: + if x, ok := r.(*dns.RRSIG); ok && x.TypeCovered == dns.TypeSOA { + z.SIG = append(z.SIG, x) + return nil + } + fallthrough + default: + z.Tree.Insert(r) + } + return nil +} // Delete deletes r from z. func (z *Zone) Delete(r dns.RR) { z.Tree.Delete(r) } @@ -59,6 +85,8 @@ func (z *Zone) TransferAllowed(state middleware.State) bool { // All returns all records from the zone, the first record will be the SOA record, // otionally followed by all RRSIG(SOA)s. func (z *Zone) All() []dns.RR { + z.reloadMu.RLock() + defer z.reloadMu.RUnlock() records := []dns.RR{} allNodes := z.Tree.All() for _, a := range allNodes { @@ -70,3 +98,50 @@ func (z *Zone) All() []dns.RR { } return append([]dns.RR{z.SOA}, records...) } + +func (z *Zone) Reload(shutdown chan bool) error { + if z.NoReload { + return nil + } + watcher, err := fsnotify.NewWatcher() + if err != nil { + return err + } + err = watcher.Add(z.file) + if err != nil { + return err + } + + go func() { + // TODO(miek): needs to be killed on reload. + for { + select { + case event := <-watcher.Events: + if event.Op&fsnotify.Write == fsnotify.Write || event.Op&fsnotify.Rename == fsnotify.Rename { + reader, err := os.Open(z.file) + if err != nil { + log.Printf("[ERROR] Failed to open `%s' for `%s': %v", z.file, z.origin, err) + continue + } + z.reloadMu.Lock() + zone, err := Parse(reader, z.origin, z.file) + if err != nil { + log.Printf("[ERROR] Failed to parse `%s': %v", z.origin, err) + z.reloadMu.Unlock() + continue + } + // copy elements we need + z.SOA = zone.SOA + z.SIG = zone.SIG + z.Tree = zone.Tree + z.reloadMu.Unlock() + log.Printf("[INFO] Successfully reload zone `%s'", z.origin) + } + case <-shutdown: + watcher.Close() + return + } + } + }() + return nil +}