diff --git a/middleware/file/notify.go b/middleware/file/notify.go index b369f6ad1..1a4e43d31 100644 --- a/middleware/file/notify.go +++ b/middleware/file/notify.go @@ -19,7 +19,7 @@ func (z *Zone) isNotify(state middleware.State) bool { if len(z.TransferFrom) == 0 { return false } - remote := middleware.Addr(state.IP()).Normalize() + remote := state.RemoteAddr() for _, from := range z.TransferFrom { if from == remote { return true diff --git a/middleware/file/secondary_test.go b/middleware/file/secondary_test.go index 3533df042..35866335a 100644 --- a/middleware/file/secondary_test.go +++ b/middleware/file/secondary_test.go @@ -1,10 +1,11 @@ package file import ( - "net" - "sync" + "fmt" "testing" - "time" + + "github.com/miekg/coredns/middleware" + coretest "github.com/miekg/coredns/middleware/testing" "github.com/miekg/dns" ) @@ -35,80 +36,56 @@ func TestLess(t *testing.T) { } } -func TCPServer(laddr string) (*dns.Server, string, error) { - l, err := net.Listen("tcp", laddr) - if err != nil { - return nil, "", err - } - - server := &dns.Server{Listener: l, ReadTimeout: time.Hour, WriteTimeout: time.Hour} - - waitLock := sync.Mutex{} - waitLock.Lock() - server.NotifyStartedFunc = waitLock.Unlock - - go func() { - server.ActivateAndServe() - l.Close() - }() - - waitLock.Lock() - return server, l.Addr().String(), nil -} - -func UDPServer(laddr string) (*dns.Server, string, chan bool, error) { - pc, err := net.ListenPacket("udp", laddr) - if err != nil { - return nil, "", nil, err - } - server := &dns.Server{PacketConn: pc, ReadTimeout: time.Hour, WriteTimeout: time.Hour} - - waitLock := sync.Mutex{} - waitLock.Lock() - server.NotifyStartedFunc = waitLock.Unlock - - stop := make(chan bool) - - go func() { - server.ActivateAndServe() - close(stop) - pc.Close() - }() - - waitLock.Lock() - return server, pc.LocalAddr().String(), stop, nil -} - type soa struct { serial uint32 } func (s *soa) Handler(w dns.ResponseWriter, req *dns.Msg) { + m := new(dns.Msg) + m.SetReply(req) + switch req.Question[0].Qtype { + case dns.TypeSOA: + m.Answer = make([]dns.RR, 1) + m.Answer[0] = coretest.SOA(fmt.Sprintf("%s IN SOA bla. bla. %d 0 0 0 0 ", testZone, s.serial)) + w.WriteMsg(m) + case dns.TypeAXFR: + m.Answer = make([]dns.RR, 4) + m.Answer[0] = coretest.SOA(fmt.Sprintf("%s IN SOA bla. bla. %d 0 0 0 0 ", testZone, s.serial)) + m.Answer[1] = coretest.A(fmt.Sprintf("%s IN A 127.0.0.1", testZone)) + m.Answer[2] = coretest.A(fmt.Sprintf("%s IN A 127.0.0.1", testZone)) + m.Answer[3] = coretest.SOA(fmt.Sprintf("%s IN SOA bla. bla. %d 0 0 0 0 ", testZone, s.serial)) + w.WriteMsg(m) + } +} + +func (s *soa) TransferHandler(w dns.ResponseWriter, req *dns.Msg) { m := new(dns.Msg) m.SetReply(req) m.Answer = make([]dns.RR, 1) - m.Answer[0] = &dns.SOA{Hdr: dns.RR_Header{Name: m.Question[0].Name, Rrtype: dns.TypeSOA, Class: dns.ClassINET, Ttl: 100}, Ns: "bla.", Mbox: "bla.", Serial: s.serial} + m.Answer[0] = coretest.SOA(fmt.Sprintf("%s IN SOA bla. bla. %d 0 0 0 0 ", testZone, s.serial)) w.WriteMsg(m) } +const testZone = "secondary.miek.nl." + func TestShouldTransfer(t *testing.T) { soa := soa{250} - dns.HandleFunc("secondary.miek.nl.", soa.Handler) - defer dns.HandleRemove("secondary.miek.nl.") + dns.HandleFunc(testZone, soa.Handler) + defer dns.HandleRemove(testZone) - s, addrstr, err := TCPServer("127.0.0.1:0") + s, addrstr, err := coretest.TCPServer("127.0.0.1:0") if err != nil { t.Fatalf("unable to run test server: %v", err) } defer s.Shutdown() z := new(Zone) - z.name = "secondary.miek.nl." + z.name = testZone z.TransferFrom = []string{addrstr} // Serial smaller - z.SOA = &dns.SOA{Hdr: dns.RR_Header{Name: "secondary.miek.nl.", Rrtype: dns.TypeSOA, Class: dns.ClassINET, Ttl: 100}, Ns: "bla.", Mbox: "bla.", Serial: soa.serial - 1} + z.SOA = coretest.SOA(fmt.Sprintf("%s IN SOA bla. bla. %d 0 0 0 0 ", testZone, soa.serial-1)) should, err := z.shouldTransfer() if err != nil { t.Fatalf("unable to run shouldTransfer: %v", err) @@ -117,7 +94,7 @@ func TestShouldTransfer(t *testing.T) { t.Fatalf("shouldTransfer should return true for serial: %q", soa.serial-1) } // Serial equal - z.SOA = &dns.SOA{Hdr: dns.RR_Header{Name: "secondary.miek.nl.", Rrtype: dns.TypeSOA, Class: dns.ClassINET, Ttl: 100}, Ns: "bla.", Mbox: "bla.", Serial: soa.serial} + z.SOA = coretest.SOA(fmt.Sprintf("%s IN SOA bla. bla. %d 0 0 0 0 ", testZone, soa.serial)) should, err = z.shouldTransfer() if err != nil { t.Fatalf("unable to run shouldTransfer: %v", err) @@ -126,3 +103,54 @@ func TestShouldTransfer(t *testing.T) { t.Fatalf("shouldTransfer should return false for serial: %d", soa.serial) } } + +func TestTransferIn(t *testing.T) { + soa := soa{250} + + dns.HandleFunc(testZone, soa.Handler) + defer dns.HandleRemove(testZone) + + s, addrstr, err := coretest.TCPServer("127.0.0.1:0") + if err != nil { + t.Fatalf("unable to run test server: %v", err) + } + defer s.Shutdown() + + z := new(Zone) + z.Expired = new(bool) + z.name = testZone + z.TransferFrom = []string{addrstr} + + err = z.TransferIn() + if err != nil { + t.Fatalf("unable to run TransferIn: %v", err) + } + if z.SOA.String() != fmt.Sprintf("%s 3600 IN SOA bla. bla. 250 0 0 0 0", testZone) { + t.Fatalf("unknown SOA transferred") + } +} + +func TestIsNotify(t *testing.T) { + z := new(Zone) + z.Expired = new(bool) + z.name = testZone + state := NewState(testZone, dns.TypeSOA) + // need to set opcode + state.Req.Opcode = dns.OpcodeNotify + + z.TransferFrom = []string{"10.240.0.1:40212"} // IP from from testing/responseWriter + if !z.isNotify(state) { + t.Fatal("should have been valid notify") + } + z.TransferFrom = []string{"10.240.0.2:40212"} + if z.isNotify(state) { + t.Fatal("should have been invalid notify") + } +} + +func NewState(zone string, qtype uint16) middleware.State { + m := new(dns.Msg) + m.SetQuestion("example.com.", dns.TypeA) + m.SetEdns0(4097, true) + return middleware.State{W: &coretest.ResponseWriter{}, Req: m} +} diff --git a/middleware/file/zone.go b/middleware/file/zone.go index 57d0cd4a0..c0bd2ffa6 100644 --- a/middleware/file/zone.go +++ b/middleware/file/zone.go @@ -31,6 +31,8 @@ func (z *Zone) Copy() *Zone { z1.TransferTo = z.TransferTo z1.TransferFrom = z.TransferFrom z1.Expired = z.Expired + z1.SOA = z.SOA + z1.SIG = z.SIG return z1 } diff --git a/middleware/state.go b/middleware/state.go index 2a7b2ac1d..fb324d780 100644 --- a/middleware/state.go +++ b/middleware/state.go @@ -50,6 +50,11 @@ func (s *State) Port() (string, error) { return port, nil } +// RemoteAddr returns the net.Addr of the client that sent the current request. +func (s *State) RemoteAddr() string { + return s.W.RemoteAddr().String() +} + // Proto gets the protocol used as the transport. This // will be udp or tcp. func (s *State) Proto() string { diff --git a/middleware/state_test.go b/middleware/state_test.go index 5b9b80f19..fe36480dd 100644 --- a/middleware/state_test.go +++ b/middleware/state_test.go @@ -17,6 +17,20 @@ func TestStateDo(t *testing.T) { } } +func TestStateRemote(t *testing.T) { + st := testState() + if st.IP() != "10.240.0.1" { + t.Fatalf("wrong IP from state") + } + p, err := st.Port() + if err != nil { + t.Fatalf("failed to get Port from state") + } + if p != "40212" { + t.Fatalf("wrong port from state") + } +} + func BenchmarkStateDo(b *testing.B) { st := testState() @@ -37,7 +51,6 @@ func testState() State { m := new(dns.Msg) m.SetQuestion("example.com.", dns.TypeA) m.SetEdns0(4097, true) - return State{W: &coretest.ResponseWriter{}, Req: m} }