Reload zone when a write is detected (#122)

Zone reloading
This commit is contained in:
Miek Gieben 2016-04-15 14:26:27 +01:00
parent 29ad957a9d
commit c9d8a57ed6
10 changed files with 179 additions and 60 deletions

View file

@ -16,18 +16,17 @@ func File(c *Controller) (middleware.Middleware, error) {
return nil, err
}
// Add startup functions to notify the master.
// Add startup functions to notify the master(s).
for _, n := range zones.Names {
if len(zones.Z[n].TransferTo) > 0 {
c.Startup = append(c.Startup, func() error {
zones.Z[n].StartupOnce.Do(func() {
if len(zones.Z[n].TransferTo) > 0 {
zones.Z[n].Notify()
}
})
return nil
c.Startup = append(c.Startup, func() error {
zones.Z[n].StartupOnce.Do(func() {
if len(zones.Z[n].TransferTo) > 0 {
zones.Z[n].Notify()
}
zones.Z[n].Reload(nil)
})
}
return nil
})
}
return func(next middleware.Handler) middleware.Handler {
@ -67,17 +66,24 @@ func fileParse(c *Controller) (file.Zones, error) {
names = append(names, origins[i])
}
noReload := false
for c.NextBlock() {
t, _, e := parseTransfer(c)
if e != nil {
return file.Zones{}, e
}
switch c.Val() {
case "no_reload":
noReload = true
}
// discard from, here, maybe check and show log when we do?
for _, origin := range origins {
if t != nil {
z[origin].TransferTo = append(z[origin].TransferTo, t...)
}
z[origin].NoReload = noReload
}
}
}
}

View file

@ -47,7 +47,7 @@ func secondaryParse(c *Controller) (file.Zones, error) {
}
for i, _ := range origins {
origins[i] = middleware.Host(origins[i]).Normalize()
z[origins[i]] = file.NewZone(origins[i])
z[origins[i]] = file.NewZone(origins[i], "stdin")
names = append(names, origins[i])
}

View file

@ -24,13 +24,15 @@ TSIG key information, something like `transfer out [address...] key [name] [base
file dbfile [zones... ] {
transfer from [address...]
transfer to [address...]
no_reload
}
~~~
* `transfer` enables zone transfers. It may be specified multiples times. *To* or *from* signals
the direction. Addresses must be denoted in CIDR notation (127.0.0.1/32 etc.) or just as plain
address. The special wildcard "*" means: the entire internet (only valid for 'transfer to').
* `no_reload` by default CoreDNS will reload a zone from disk whenever it detects a change to the
file. This option disables that behavior.
## Examples

View file

@ -105,27 +105,14 @@ func (f File) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (i
// Parse parses the zone in filename and returns a new Zone or an error.
func Parse(f io.Reader, origin, fileName string) (*Zone, error) {
tokens := dns.ParseZone(f, dns.Fqdn(origin), fileName)
z := NewZone(origin)
z := NewZone(origin, fileName)
for x := range tokens {
if x.Error != nil {
log.Printf("[ERROR] Failed to parse `%s': %v", origin, x.Error)
return nil, x.Error
}
switch h := x.RR.Header().Rrtype; h {
case dns.TypeSOA:
z.SOA = x.RR.(*dns.SOA)
case dns.TypeNSEC3, dns.TypeNSEC3PARAM:
err := fmt.Errorf("NSEC3 zone is not supported, dropping")
log.Printf("[ERROR] Failed to parse `%s': %v", origin, err)
if err := z.Insert(x.RR); err != nil {
return nil, err
case dns.TypeRRSIG:
if x, ok := x.RR.(*dns.RRSIG); ok && x.TypeCovered == dns.TypeSOA {
z.SIG = append(z.SIG, x)
continue
}
fallthrough
default:
z.Insert(x.RR)
}
}
return z, nil

View file

@ -30,7 +30,7 @@ func (z *Zone) isNotify(state middleware.State) bool {
// Notify will send notifies to all configured TransferTo IP addresses.
func (z *Zone) Notify() {
go notify(z.name, z.TransferTo)
go notify(z.origin, z.TransferTo)
}
// notify sends notifies to the configured remote servers. It will try up to three times

View file

@ -0,0 +1,62 @@
package file
import (
"io/ioutil"
"os"
"testing"
"time"
"github.com/miekg/coredns/middleware/test"
"github.com/miekg/dns"
)
func TestZoneReload(t *testing.T) {
fileName, rm, err := test.Zone(t, ".", reloadZoneTest)
if err != nil {
t.Fatalf("failed to create zone: %s", err)
}
defer rm()
reader, err := os.Open(fileName)
if err != nil {
t.Fatalf("failed to open zone: %s", err)
}
z, err := Parse(reader, "miek.nl", fileName)
if err != nil {
t.Fatalf("failed to parse zone: %s", err)
}
z.Reload(nil)
if _, _, _, res := z.Lookup("miek.nl.", dns.TypeSOA, false); res != Success {
t.Fatalf("failed to lookup, got %d", res)
}
if _, _, _, res := z.Lookup("miek.nl.", dns.TypeNS, false); res != Success {
t.Fatalf("failed to lookup, got %d", res)
}
if len(z.All()) != 5 {
t.Fatalf("expected 5 RRs, got %d", len(z.All()))
}
if err := ioutil.WriteFile(fileName, []byte(reloadZone2Test), 0644); err != nil {
t.Fatalf("failed to write new zone data", err)
}
// Could still be racy, but we need to wait a bit for the event to be seen
time.Sleep(1 * time.Second)
if len(z.All()) != 3 {
t.Fatalf("expected 3 RRs, got %d", len(z.All()))
}
}
const reloadZoneTest = `miek.nl. 1627 IN SOA linode.atoom.net. miek.miek.nl. 1460175181 14400 3600 604800 14400
miek.nl. 1627 IN NS ext.ns.whyscream.net.
miek.nl. 1627 IN NS omval.tednet.nl.
miek.nl. 1627 IN NS linode.atoom.net.
miek.nl. 1627 IN NS ns-ext.nlnetlabs.nl.
`
const reloadZone2Test = `miek.nl. 1627 IN SOA linode.atoom.net. miek.miek.nl. 1460175181 14400 3600 604800 14400
miek.nl. 1627 IN NS ext.ns.whyscream.net.
miek.nl. 1627 IN NS omval.tednet.nl.
`

View file

@ -1,7 +1,6 @@
package file
import (
"fmt"
"log"
"time"
@ -16,7 +15,7 @@ func (z *Zone) TransferIn() error {
return nil
}
m := new(dns.Msg)
m.SetAxfr(z.name)
m.SetAxfr(z.origin)
z1 := z.Copy()
var (
@ -29,32 +28,20 @@ Transfer:
t := new(dns.Transfer)
c, err := t.In(m, tr)
if err != nil {
log.Printf("[ERROR] Failed to setup transfer `%s' with `%s': %v", z.name, tr, err)
log.Printf("[ERROR] Failed to setup transfer `%s' with `%s': %v", z.origin, tr, err)
Err = err
continue Transfer
}
for env := range c {
if env.Error != nil {
log.Printf("[ERROR] Failed to parse transfer `%s': %v", z.name, env.Error)
log.Printf("[ERROR] Failed to parse transfer `%s': %v", z.origin, env.Error)
Err = env.Error
continue Transfer
}
for _, rr := range env.RR {
switch h := rr.Header().Rrtype; h {
case dns.TypeSOA:
z1.SOA = rr.(*dns.SOA)
case dns.TypeNSEC3, dns.TypeNSEC3PARAM:
err := fmt.Errorf("NSEC3 zone is not supported, dropping")
log.Printf("[ERROR] Failed to parse transfer `%s': %v", z.name, err)
if err := z1.Insert(rr); err != nil {
log.Printf("[ERROR] Failed to parse transfer `%s': %v", z.origin, err)
return err
case dns.TypeRRSIG:
if x, ok := rr.(*dns.RRSIG); ok && x.TypeCovered == dns.TypeSOA {
z1.SIG = append(z1.SIG, x)
continue
}
fallthrough
default:
z1.Insert(rr)
}
}
}
@ -62,7 +49,7 @@ Transfer:
break
}
if Err != nil {
log.Printf("[ERROR] Failed to transfer %s: %s", z.name, Err)
log.Printf("[ERROR] Failed to transfer %s: %s", z.origin, Err)
return Err
}
@ -70,7 +57,7 @@ Transfer:
z.SOA = z1.SOA
z.SIG = z1.SIG
*z.Expired = false
log.Printf("[INFO] Transferred: %s from %s", z.name, tr)
log.Printf("[INFO] Transferred: %s from %s", z.origin, tr)
return nil
}
@ -80,7 +67,7 @@ func (z *Zone) shouldTransfer() (bool, error) {
c := new(dns.Client)
c.Net = "tcp" // do this query over TCP to minimize spoofing
m := new(dns.Msg)
m.SetQuestion(z.name, dns.TypeSOA)
m.SetQuestion(z.origin, dns.TypeSOA)
var Err error
serial := -1

View file

@ -81,7 +81,7 @@ func TestShouldTransfer(t *testing.T) {
defer s.Shutdown()
z := new(Zone)
z.name = testZone
z.origin = testZone
z.TransferFrom = []string{addrstr}
// Serial smaller
@ -118,7 +118,7 @@ func TestTransferIn(t *testing.T) {
z := new(Zone)
z.Expired = new(bool)
z.name = testZone
z.origin = testZone
z.TransferFrom = []string{addrstr}
err = z.TransferIn()
@ -133,7 +133,7 @@ func TestTransferIn(t *testing.T) {
func TestIsNotify(t *testing.T) {
z := new(Zone)
z.Expired = new(bool)
z.name = testZone
z.origin = testZone
state := NewState(testZone, dns.TypeSOA)
// need to set opcode
state.Req.Opcode = dns.OpcodeNotify

View file

@ -38,7 +38,7 @@ func (x Xfr) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (in
j, l := 0, 0
records = append(records, records[0]) // add closing SOA to the end
log.Printf("[INFO] Outgoing transfer of %d records of zone %s to %s started", len(records), x.name, state.IP())
log.Printf("[INFO] Outgoing transfer of %d records of zone %s to %s started", len(records), x.origin, state.IP())
for i, r := range records {
l += dns.Len(r)
if l > transferLength {

View file

@ -1,36 +1,45 @@
package file
import (
"fmt"
"log"
"os"
"sync"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/file/tree"
"github.com/fsnotify/fsnotify"
"github.com/miekg/dns"
)
type Zone struct {
SOA *dns.SOA
SIG []dns.RR
name string
SOA *dns.SOA
SIG []dns.RR
origin string
file string
*tree.Tree
TransferTo []string
StartupOnce sync.Once
TransferFrom []string
Expired *bool
NoReload bool
reloadMu sync.RWMutex
// TODO: shutdown watcher channel
}
// NewZone returns a new zone.
func NewZone(name string) *Zone {
z := &Zone{name: dns.Fqdn(name), Tree: &tree.Tree{}, Expired: new(bool)}
func NewZone(name, file string) *Zone {
z := &Zone{origin: dns.Fqdn(name), file: file, Tree: &tree.Tree{}, Expired: new(bool)}
*z.Expired = false
return z
}
// Copy copies a zone *without* copying the zone's content. It is not a deep copy.
func (z *Zone) Copy() *Zone {
z1 := NewZone(z.name)
z1 := NewZone(z.origin, z.file)
z1.TransferTo = z.TransferTo
z1.TransferFrom = z.TransferFrom
z1.Expired = z.Expired
@ -40,7 +49,24 @@ func (z *Zone) Copy() *Zone {
}
// Insert inserts r into z.
func (z *Zone) Insert(r dns.RR) { z.Tree.Insert(r) }
func (z *Zone) Insert(r dns.RR) error {
switch h := r.Header().Rrtype; h {
case dns.TypeSOA:
z.SOA = r.(*dns.SOA)
return nil
case dns.TypeNSEC3, dns.TypeNSEC3PARAM:
return fmt.Errorf("NSEC3 zone is not supported, dropping")
case dns.TypeRRSIG:
if x, ok := r.(*dns.RRSIG); ok && x.TypeCovered == dns.TypeSOA {
z.SIG = append(z.SIG, x)
return nil
}
fallthrough
default:
z.Tree.Insert(r)
}
return nil
}
// Delete deletes r from z.
func (z *Zone) Delete(r dns.RR) { z.Tree.Delete(r) }
@ -59,6 +85,8 @@ func (z *Zone) TransferAllowed(state middleware.State) bool {
// All returns all records from the zone, the first record will be the SOA record,
// otionally followed by all RRSIG(SOA)s.
func (z *Zone) All() []dns.RR {
z.reloadMu.RLock()
defer z.reloadMu.RUnlock()
records := []dns.RR{}
allNodes := z.Tree.All()
for _, a := range allNodes {
@ -70,3 +98,50 @@ func (z *Zone) All() []dns.RR {
}
return append([]dns.RR{z.SOA}, records...)
}
func (z *Zone) Reload(shutdown chan bool) error {
if z.NoReload {
return nil
}
watcher, err := fsnotify.NewWatcher()
if err != nil {
return err
}
err = watcher.Add(z.file)
if err != nil {
return err
}
go func() {
// TODO(miek): needs to be killed on reload.
for {
select {
case event := <-watcher.Events:
if event.Op&fsnotify.Write == fsnotify.Write || event.Op&fsnotify.Rename == fsnotify.Rename {
reader, err := os.Open(z.file)
if err != nil {
log.Printf("[ERROR] Failed to open `%s' for `%s': %v", z.file, z.origin, err)
continue
}
z.reloadMu.Lock()
zone, err := Parse(reader, z.origin, z.file)
if err != nil {
log.Printf("[ERROR] Failed to parse `%s': %v", z.origin, err)
z.reloadMu.Unlock()
continue
}
// copy elements we need
z.SOA = zone.SOA
z.SIG = zone.SIG
z.Tree = zone.Tree
z.reloadMu.Unlock()
log.Printf("[INFO] Successfully reload zone `%s'", z.origin)
}
case <-shutdown:
watcher.Close()
return
}
}
}()
return nil
}