diff --git a/plugin/dnstap/dnstapio/io.go b/plugin/dnstap/dnstapio/io.go index 3bc44f85e..08d33aa0c 100644 --- a/plugin/dnstap/dnstapio/io.go +++ b/plugin/dnstap/dnstapio/io.go @@ -13,34 +13,39 @@ import ( const ( tcpTimeout = 4 * time.Second flushTimeout = 1 * time.Second - queueSize = 1000 + queueSize = 10000 ) type dnstapIO struct { - enc *fs.Encoder - conn net.Conn - queue chan tap.Dnstap + endpoint string + socket bool + conn net.Conn + enc *fs.Encoder + queue chan tap.Dnstap } // New returns a new and initialized DnstapIO. -func New() DnstapIO { - return &dnstapIO{queue: make(chan tap.Dnstap, queueSize)} +func New(endpoint string, socket bool) DnstapIO { + return &dnstapIO{ + endpoint: endpoint, + socket: socket, + queue: make(chan tap.Dnstap, queueSize), + } } // DnstapIO interface type DnstapIO interface { - Connect(endpoint string, socket bool) error + Connect() Dnstap(payload tap.Dnstap) Close() } -// Connect connects to the dnstop endpoint. -func (dio *dnstapIO) Connect(endpoint string, socket bool) error { +func (dio *dnstapIO) newConnect() error { var err error - if socket { - dio.conn, err = net.Dial("unix", endpoint) + if dio.socket { + dio.conn, err = net.Dial("unix", dio.endpoint) } else { - dio.conn, err = net.DialTimeout("tcp", endpoint, tcpTimeout) + dio.conn, err = net.DialTimeout("tcp", dio.endpoint, tcpTimeout) } if err != nil { return err @@ -52,10 +57,17 @@ func (dio *dnstapIO) Connect(endpoint string, socket bool) error { if err != nil { return err } - go dio.serve() return nil } +// Connect connects to the dnstop endpoint. +func (dio *dnstapIO) Connect() { + if err := dio.newConnect(); err != nil { + log.Printf("[ERROR] No connection to dnstap endpoint") + } + go dio.serve() +} + // Dnstap enqueues the payload for log. func (dio *dnstapIO) Dnstap(payload tap.Dnstap) { select { @@ -65,36 +77,59 @@ func (dio *dnstapIO) Dnstap(payload tap.Dnstap) { } } +func (dio *dnstapIO) closeConnection() { + dio.enc.Close() + dio.conn.Close() + dio.enc = nil + dio.conn = nil +} + // Close waits until the I/O routine is finished to return. func (dio *dnstapIO) Close() { close(dio.queue) } +func (dio *dnstapIO) write(payload *tap.Dnstap) { + if dio.enc == nil { + if err := dio.newConnect(); err != nil { + return + } + } + var err error + if payload != nil { + frame, e := proto.Marshal(payload) + if err != nil { + log.Printf("[ERROR] Invalid dnstap payload dropped: %s", e) + return + } + _, err = dio.enc.Write(frame) + } else { + err = dio.enc.Flush() + } + if err == nil { + return + } + log.Printf("[WARN] Connection lost: %s", err) + dio.closeConnection() + if err := dio.newConnect(); err != nil { + log.Printf("[ERROR] Cannot write dnstap payload: %s", err) + } else { + log.Printf("[INFO] Reconnect to dnstap done") + } +} + func (dio *dnstapIO) serve() { timeout := time.After(flushTimeout) for { select { case payload, ok := <-dio.queue: if !ok { - dio.enc.Close() - dio.conn.Close() + dio.closeConnection() return } - frame, err := proto.Marshal(&payload) - if err != nil { - log.Printf("[ERROR] Invalid dnstap payload dropped: %s", err) - continue - } - _, err = dio.enc.Write(frame) - if err != nil { - log.Printf("[ERROR] Cannot write dnstap payload: %s", err) - continue - } + dio.write(&payload) case <-timeout: - err := dio.enc.Flush() - if err != nil { - log.Printf("[ERROR] Cannot flush dnstap payloads: %s", err) - } + dio.write(nil) timeout = time.After(flushTimeout) } } diff --git a/plugin/dnstap/dnstapio/io_test.go b/plugin/dnstap/dnstapio/io_test.go index bfeeb4289..c74ac6f73 100644 --- a/plugin/dnstap/dnstapio/io_test.go +++ b/plugin/dnstap/dnstapio/io_test.go @@ -10,6 +10,16 @@ import ( fs "github.com/farsightsec/golang-framestream" ) +const ( + endpointTCP = "localhost:0" + endpointSocket = "dnstap.sock" +) + +var ( + msgType = tap.Dnstap_MESSAGE + msg = tap.Dnstap{Type: &msgType} +) + func accept(t *testing.T, l net.Listener, count int) { server, err := l.Accept() if err != nil { @@ -37,88 +47,39 @@ func accept(t *testing.T, l net.Listener, count int) { } } -const endpointTCP = "localhost:0" - -func TestTCP(t *testing.T) { - dio := New() - - err := dio.Connect(endpointTCP, false) - if err == nil { - t.Fatal("Not listening but no error") +func TestTransport(t *testing.T) { + transport := [2][3]string{ + {"tcp", endpointTCP, "false"}, + {"unix", endpointSocket, "true"}, } - // Start TCP listener - l, err := net.Listen("tcp", endpointTCP) - if err != nil { - t.Fatalf("Cannot start listener: %s", err) + for _, param := range transport { + // Start TCP listener + l, err := net.Listen(param[0], param[1]) + if err != nil { + t.Fatalf("Cannot start listener: %s", err) + } + + var wg sync.WaitGroup + wg.Add(1) + go func() { + accept(t, l, 1) + wg.Done() + }() + + dio := New(l.Addr().String(), param[2] == "true") + dio.Connect() + + dio.Dnstap(msg) + + wg.Wait() + l.Close() + dio.Close() } - defer l.Close() - - var wg sync.WaitGroup - wg.Add(1) - go func() { - accept(t, l, 1) - wg.Done() - }() - - err = dio.Connect(l.Addr().String(), false) - if err != nil { - t.Fatalf("Cannot connect to listener: %s", err) - } - - msg := tap.Dnstap_MESSAGE - dio.Dnstap(tap.Dnstap{Type: &msg}) - - wg.Wait() - - dio.Close() -} - -const endpointSocket = "dnstap.sock" - -func TestSocket(t *testing.T) { - dio := New() - - err := dio.Connect(endpointSocket, true) - if err == nil { - t.Fatal("Not listening but no error") - } - - // Start Socket listener - l, err := net.Listen("unix", endpointSocket) - if err != nil { - t.Fatalf("Cannot start listener: %s", err) - } - defer l.Close() - - var wg sync.WaitGroup - wg.Add(1) - go func() { - accept(t, l, 1) - wg.Done() - }() - - err = dio.Connect(endpointSocket, true) - if err != nil { - t.Fatalf("Cannot connect to listener: %s", err) - } - - msg := tap.Dnstap_MESSAGE - dio.Dnstap(tap.Dnstap{Type: &msg}) - - wg.Wait() - - dio.Close() } func TestRace(t *testing.T) { count := 10 - dio := New() - - err := dio.Connect(endpointTCP, false) - if err == nil { - t.Fatal("Not listening but no error") - } // Start TCP listener l, err := net.Listen("tcp", endpointTCP) @@ -134,22 +95,68 @@ func TestRace(t *testing.T) { wg.Done() }() - err = dio.Connect(l.Addr().String(), false) - if err != nil { - t.Fatalf("Cannot connect to listener: %s", err) - } + dio := New(l.Addr().String(), false) + dio.Connect() + defer dio.Close() - msg := tap.Dnstap_MESSAGE wg.Add(count) for i := 0; i < count; i++ { - go func(i byte) { + go func() { time.Sleep(50 * time.Millisecond) - dio.Dnstap(tap.Dnstap{Type: &msg, Extra: []byte{i}}) + dio.Dnstap(msg) wg.Done() - }(byte(i)) + }() + } + + wg.Wait() +} + +func TestReconnect(t *testing.T) { + count := 5 + + // Start TCP listener + l, err := net.Listen("tcp", endpointTCP) + if err != nil { + t.Fatalf("Cannot start listener: %s", err) + } + + var wg sync.WaitGroup + wg.Add(1) + go func() { + accept(t, l, 1) + wg.Done() + }() + + addr := l.Addr().String() + dio := New(addr, false) + dio.Connect() + defer dio.Close() + + msg := tap.Dnstap_MESSAGE + dio.Dnstap(tap.Dnstap{Type: &msg}) + + wg.Wait() + + // Close listener + l.Close() + + // And start TCP listener again on the same port + l, err = net.Listen("tcp", addr) + if err != nil { + t.Fatalf("Cannot start listener: %s", err) + } + defer l.Close() + + wg.Add(1) + go func() { + accept(t, l, 1) + wg.Done() + }() + + for i := 0; i < count; i++ { + time.Sleep(time.Second) + dio.Dnstap(tap.Dnstap{Type: &msg}) } wg.Wait() - - dio.Close() } diff --git a/plugin/dnstap/setup.go b/plugin/dnstap/setup.go index 342f14e88..4c6ae1d4f 100644 --- a/plugin/dnstap/setup.go +++ b/plugin/dnstap/setup.go @@ -65,14 +65,11 @@ func setup(c *caddy.Controller) error { return err } - dio := dnstapio.New() + dio := dnstapio.New(conf.target, conf.socket) dnstap := Dnstap{IO: dio, Pack: conf.full} c.OnStartup(func() error { - err := dio.Connect(conf.target, conf.socket) - if err != nil { - return plugin.Error("dnstap", err) - } + dio.Connect() return nil })