Added reconnect feature for dnstap plugin (#1267)

This commit is contained in:
Uladzimir Trehubenka 2017-12-01 13:16:14 +02:00 committed by Miek Gieben
parent 917965fa86
commit 861e2382c2
3 changed files with 158 additions and 119 deletions

View file

@ -13,34 +13,39 @@ import (
const ( const (
tcpTimeout = 4 * time.Second tcpTimeout = 4 * time.Second
flushTimeout = 1 * time.Second flushTimeout = 1 * time.Second
queueSize = 1000 queueSize = 10000
) )
type dnstapIO struct { type dnstapIO struct {
enc *fs.Encoder endpoint string
conn net.Conn socket bool
queue chan tap.Dnstap conn net.Conn
enc *fs.Encoder
queue chan tap.Dnstap
} }
// New returns a new and initialized DnstapIO. // New returns a new and initialized DnstapIO.
func New() DnstapIO { func New(endpoint string, socket bool) DnstapIO {
return &dnstapIO{queue: make(chan tap.Dnstap, queueSize)} return &dnstapIO{
endpoint: endpoint,
socket: socket,
queue: make(chan tap.Dnstap, queueSize),
}
} }
// DnstapIO interface // DnstapIO interface
type DnstapIO interface { type DnstapIO interface {
Connect(endpoint string, socket bool) error Connect()
Dnstap(payload tap.Dnstap) Dnstap(payload tap.Dnstap)
Close() Close()
} }
// Connect connects to the dnstop endpoint. func (dio *dnstapIO) newConnect() error {
func (dio *dnstapIO) Connect(endpoint string, socket bool) error {
var err error var err error
if socket { if dio.socket {
dio.conn, err = net.Dial("unix", endpoint) dio.conn, err = net.Dial("unix", dio.endpoint)
} else { } else {
dio.conn, err = net.DialTimeout("tcp", endpoint, tcpTimeout) dio.conn, err = net.DialTimeout("tcp", dio.endpoint, tcpTimeout)
} }
if err != nil { if err != nil {
return err return err
@ -52,10 +57,17 @@ func (dio *dnstapIO) Connect(endpoint string, socket bool) error {
if err != nil { if err != nil {
return err return err
} }
go dio.serve()
return nil return nil
} }
// Connect connects to the dnstop endpoint.
func (dio *dnstapIO) Connect() {
if err := dio.newConnect(); err != nil {
log.Printf("[ERROR] No connection to dnstap endpoint")
}
go dio.serve()
}
// Dnstap enqueues the payload for log. // Dnstap enqueues the payload for log.
func (dio *dnstapIO) Dnstap(payload tap.Dnstap) { func (dio *dnstapIO) Dnstap(payload tap.Dnstap) {
select { select {
@ -65,36 +77,59 @@ func (dio *dnstapIO) Dnstap(payload tap.Dnstap) {
} }
} }
func (dio *dnstapIO) closeConnection() {
dio.enc.Close()
dio.conn.Close()
dio.enc = nil
dio.conn = nil
}
// Close waits until the I/O routine is finished to return. // Close waits until the I/O routine is finished to return.
func (dio *dnstapIO) Close() { func (dio *dnstapIO) Close() {
close(dio.queue) close(dio.queue)
} }
func (dio *dnstapIO) write(payload *tap.Dnstap) {
if dio.enc == nil {
if err := dio.newConnect(); err != nil {
return
}
}
var err error
if payload != nil {
frame, e := proto.Marshal(payload)
if err != nil {
log.Printf("[ERROR] Invalid dnstap payload dropped: %s", e)
return
}
_, err = dio.enc.Write(frame)
} else {
err = dio.enc.Flush()
}
if err == nil {
return
}
log.Printf("[WARN] Connection lost: %s", err)
dio.closeConnection()
if err := dio.newConnect(); err != nil {
log.Printf("[ERROR] Cannot write dnstap payload: %s", err)
} else {
log.Printf("[INFO] Reconnect to dnstap done")
}
}
func (dio *dnstapIO) serve() { func (dio *dnstapIO) serve() {
timeout := time.After(flushTimeout) timeout := time.After(flushTimeout)
for { for {
select { select {
case payload, ok := <-dio.queue: case payload, ok := <-dio.queue:
if !ok { if !ok {
dio.enc.Close() dio.closeConnection()
dio.conn.Close()
return return
} }
frame, err := proto.Marshal(&payload) dio.write(&payload)
if err != nil {
log.Printf("[ERROR] Invalid dnstap payload dropped: %s", err)
continue
}
_, err = dio.enc.Write(frame)
if err != nil {
log.Printf("[ERROR] Cannot write dnstap payload: %s", err)
continue
}
case <-timeout: case <-timeout:
err := dio.enc.Flush() dio.write(nil)
if err != nil {
log.Printf("[ERROR] Cannot flush dnstap payloads: %s", err)
}
timeout = time.After(flushTimeout) timeout = time.After(flushTimeout)
} }
} }

View file

@ -10,6 +10,16 @@ import (
fs "github.com/farsightsec/golang-framestream" fs "github.com/farsightsec/golang-framestream"
) )
const (
endpointTCP = "localhost:0"
endpointSocket = "dnstap.sock"
)
var (
msgType = tap.Dnstap_MESSAGE
msg = tap.Dnstap{Type: &msgType}
)
func accept(t *testing.T, l net.Listener, count int) { func accept(t *testing.T, l net.Listener, count int) {
server, err := l.Accept() server, err := l.Accept()
if err != nil { if err != nil {
@ -37,88 +47,39 @@ func accept(t *testing.T, l net.Listener, count int) {
} }
} }
const endpointTCP = "localhost:0" func TestTransport(t *testing.T) {
transport := [2][3]string{
func TestTCP(t *testing.T) { {"tcp", endpointTCP, "false"},
dio := New() {"unix", endpointSocket, "true"},
err := dio.Connect(endpointTCP, false)
if err == nil {
t.Fatal("Not listening but no error")
} }
// Start TCP listener for _, param := range transport {
l, err := net.Listen("tcp", endpointTCP) // Start TCP listener
if err != nil { l, err := net.Listen(param[0], param[1])
t.Fatalf("Cannot start listener: %s", err) if err != nil {
t.Fatalf("Cannot start listener: %s", err)
}
var wg sync.WaitGroup
wg.Add(1)
go func() {
accept(t, l, 1)
wg.Done()
}()
dio := New(l.Addr().String(), param[2] == "true")
dio.Connect()
dio.Dnstap(msg)
wg.Wait()
l.Close()
dio.Close()
} }
defer l.Close()
var wg sync.WaitGroup
wg.Add(1)
go func() {
accept(t, l, 1)
wg.Done()
}()
err = dio.Connect(l.Addr().String(), false)
if err != nil {
t.Fatalf("Cannot connect to listener: %s", err)
}
msg := tap.Dnstap_MESSAGE
dio.Dnstap(tap.Dnstap{Type: &msg})
wg.Wait()
dio.Close()
}
const endpointSocket = "dnstap.sock"
func TestSocket(t *testing.T) {
dio := New()
err := dio.Connect(endpointSocket, true)
if err == nil {
t.Fatal("Not listening but no error")
}
// Start Socket listener
l, err := net.Listen("unix", endpointSocket)
if err != nil {
t.Fatalf("Cannot start listener: %s", err)
}
defer l.Close()
var wg sync.WaitGroup
wg.Add(1)
go func() {
accept(t, l, 1)
wg.Done()
}()
err = dio.Connect(endpointSocket, true)
if err != nil {
t.Fatalf("Cannot connect to listener: %s", err)
}
msg := tap.Dnstap_MESSAGE
dio.Dnstap(tap.Dnstap{Type: &msg})
wg.Wait()
dio.Close()
} }
func TestRace(t *testing.T) { func TestRace(t *testing.T) {
count := 10 count := 10
dio := New()
err := dio.Connect(endpointTCP, false)
if err == nil {
t.Fatal("Not listening but no error")
}
// Start TCP listener // Start TCP listener
l, err := net.Listen("tcp", endpointTCP) l, err := net.Listen("tcp", endpointTCP)
@ -134,22 +95,68 @@ func TestRace(t *testing.T) {
wg.Done() wg.Done()
}() }()
err = dio.Connect(l.Addr().String(), false) dio := New(l.Addr().String(), false)
if err != nil { dio.Connect()
t.Fatalf("Cannot connect to listener: %s", err) defer dio.Close()
}
msg := tap.Dnstap_MESSAGE
wg.Add(count) wg.Add(count)
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
go func(i byte) { go func() {
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
dio.Dnstap(tap.Dnstap{Type: &msg, Extra: []byte{i}}) dio.Dnstap(msg)
wg.Done() wg.Done()
}(byte(i)) }()
}
wg.Wait()
}
func TestReconnect(t *testing.T) {
count := 5
// Start TCP listener
l, err := net.Listen("tcp", endpointTCP)
if err != nil {
t.Fatalf("Cannot start listener: %s", err)
}
var wg sync.WaitGroup
wg.Add(1)
go func() {
accept(t, l, 1)
wg.Done()
}()
addr := l.Addr().String()
dio := New(addr, false)
dio.Connect()
defer dio.Close()
msg := tap.Dnstap_MESSAGE
dio.Dnstap(tap.Dnstap{Type: &msg})
wg.Wait()
// Close listener
l.Close()
// And start TCP listener again on the same port
l, err = net.Listen("tcp", addr)
if err != nil {
t.Fatalf("Cannot start listener: %s", err)
}
defer l.Close()
wg.Add(1)
go func() {
accept(t, l, 1)
wg.Done()
}()
for i := 0; i < count; i++ {
time.Sleep(time.Second)
dio.Dnstap(tap.Dnstap{Type: &msg})
} }
wg.Wait() wg.Wait()
dio.Close()
} }

View file

@ -65,14 +65,11 @@ func setup(c *caddy.Controller) error {
return err return err
} }
dio := dnstapio.New() dio := dnstapio.New(conf.target, conf.socket)
dnstap := Dnstap{IO: dio, Pack: conf.full} dnstap := Dnstap{IO: dio, Pack: conf.full}
c.OnStartup(func() error { c.OnStartup(func() error {
err := dio.Connect(conf.target, conf.socket) dio.Connect()
if err != nil {
return plugin.Error("dnstap", err)
}
return nil return nil
}) })