diff --git a/plugin/auto/walk.go b/plugin/auto/walk.go index 351210509..9b60cb9be 100644 --- a/plugin/auto/walk.go +++ b/plugin/auto/walk.go @@ -31,9 +31,10 @@ func (a Auto) Walk() error { return nil } - if _, ok := a.Zones.Z[origin]; ok { + if z, ok := a.Zones.Z[origin]; ok { // we already have this zone toDelete[origin] = false + z.SetFile(path) return nil } diff --git a/plugin/auto/watcher_test.go b/plugin/auto/watcher_test.go index dde7053fd..1d707d747 100644 --- a/plugin/auto/watcher_test.go +++ b/plugin/auto/watcher_test.go @@ -51,4 +51,51 @@ func TestWatcher(t *testing.T) { if _, ok := a.Zones.Z["example.org."]; !ok { t.Errorf("Expected %q to still be there.", "example.org.") } + +} + +func TestSymlinks(t *testing.T) { + tempdir, err := createFiles() + if err != nil { + if tempdir != "" { + os.RemoveAll(tempdir) + } + t.Fatal(err) + } + defer os.RemoveAll(tempdir) + + ldr := loader{ + directory: tempdir, + re: regexp.MustCompile(`db\.(.*)`), + template: `${1}`, + } + + a := Auto{ + loader: ldr, + Zones: &Zones{}, + } + + a.Walk() + + // Now create a duplicate file in a subdirectory and repoint the symlink + if err := os.Remove(path.Join(tempdir, "db.example.com")); err != nil { + t.Fatal(err) + } + dataDir := path.Join(tempdir, "..data") + if err = os.Mkdir(dataDir, 0755); err != nil { + t.Fatal(err) + } + newFile := path.Join(dataDir, "db.example.com") + if err = os.Symlink(path.Join(tempdir, "db.example.org"), newFile); err != nil { + t.Fatal(err) + } + + a.Walk() + + if storedZone, ok := a.Zones.Z["example.com."]; ok { + storedFile := storedZone.File() + if storedFile != newFile { + t.Errorf("Expected %q to reflect new path %q", storedFile, newFile) + } + } } diff --git a/plugin/file/reload.go b/plugin/file/reload.go index af06b98ac..7c0fe8544 100644 --- a/plugin/file/reload.go +++ b/plugin/file/reload.go @@ -22,14 +22,15 @@ func (z *Zone) Reload() error { select { case <-tick.C: - reader, err := os.Open(z.file) + zFile := z.File() + reader, err := os.Open(zFile) if err != nil { - log.Errorf("Failed to open zone %q in %q: %v", z.origin, z.file, err) + log.Errorf("Failed to open zone %q in %q: %v", z.origin, zFile, err) continue } serial := z.SOASerialIfDefined() - zone, err := Parse(reader, z.origin, z.file, serial) + zone, err := Parse(reader, z.origin, zFile, serial) if err != nil { if _, ok := err.(*serialErr); !ok { log.Errorf("Parsing zone %q: %v", z.origin, err) @@ -43,7 +44,7 @@ func (z *Zone) Reload() error { z.Tree = zone.Tree z.reloadMu.Unlock() - log.Infof("Successfully reloaded zone %q in %q with serial %d", z.origin, z.file, z.Apex.SOA.Serial) + log.Infof("Successfully reloaded zone %q in %q with serial %d", z.origin, zFile, z.Apex.SOA.Serial) z.Notify() case <-z.reloadShutdown: diff --git a/plugin/file/zone.go b/plugin/file/zone.go index da294ed45..6e1ec6d69 100644 --- a/plugin/file/zone.go +++ b/plugin/file/zone.go @@ -124,6 +124,20 @@ func (z *Zone) Insert(r dns.RR) error { // Delete deletes r from z. func (z *Zone) Delete(r dns.RR) { z.Tree.Delete(r) } +// File retrieves the file path in a safe way +func (z *Zone) File() string { + z.reloadMu.Lock() + defer z.reloadMu.Unlock() + return z.file +} + +// SetFile updates the file path in a safe way +func (z *Zone) SetFile(path string) { + z.reloadMu.Lock() + z.file = path + z.reloadMu.Unlock() +} + // TransferAllowed checks if incoming request for transferring the zone is allowed according to the ACLs. func (z *Zone) TransferAllowed(state request.Request) bool { for _, t := range z.TransferTo {