diff --git a/plugin/dnstap/README.md b/plugin/dnstap/README.md index 67345284f..b2bbd3c0f 100644 --- a/plugin/dnstap/README.md +++ b/plugin/dnstap/README.md @@ -66,28 +66,39 @@ $ dnstap -l 127.0.0.1:6000 ## Using Dnstap in your plugin -~~~ Go -import ( - "github.com/coredns/coredns/plugin/dnstap" - "github.com/coredns/coredns/plugin/dnstap/msg" -) +In your setup function, check to see if the *dnstap* plugin is loaded: -func (h Dnstap) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { - // log client query to Dnstap - if t := dnstap.TapperFromContext(ctx); t != nil { - b := msg.New().Time(time.Now()).Addr(w.RemoteAddr()) - if t.Pack() { - b.Msg(r) - } - if m, err := b.ToClientQuery(); err == nil { - t.TapMessage(m) +~~~ go +c.OnStartup(func() error { + if taph := dnsserver.GetConfig(c).Handler("dnstap"); taph != nil { + if tapPlugin, ok := taph.(dnstap.Dnstap); ok { + f.tapPlugin = &tapPlugin } } + return nil +}) +~~~ +And then in your plugin: + +~~~ go +func (x RandomPlugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + if tapPlugin != nil { + q := new(msg.Msg) + msg.SetQueryTime(q, time.Now()) + msg.SetQueryAddress(q, w.RemoteAddr()) + if tapPlugin.IncludeRawMessage { + buf, _ := r.Pack() // r has been seen packed/unpacked before, this should not fail + q.QueryMessage = buf + } + msg.SetType(q, tap.Message_CLIENT_QUERY) + tapPlugin.TapMessage(q) + } // ... } ~~~ ## See Also -[dnstap.info](https://dnstap.info). +The website [dnstap.info](https://dnstap.info) has info on the dnstap protocol. +The *forward* plugin's `dnstap.go` uses dnstap to tap messages sent to an upstream. diff --git a/plugin/dnstap/context_test.go b/plugin/dnstap/context_test.go deleted file mode 100644 index 64418f59b..000000000 --- a/plugin/dnstap/context_test.go +++ /dev/null @@ -1,15 +0,0 @@ -package dnstap - -import ( - "context" - "testing" -) - -func TestDnstapContext(t *testing.T) { - ctx := ContextWithTapper(context.TODO(), Dnstap{}) - tapper := TapperFromContext(ctx) - - if tapper == nil { - t.Fatal("Can't get tapper") - } -} diff --git a/plugin/dnstap/dnstapio/io.go b/plugin/dnstap/dnstapio/io.go index 9a4c26042..c88fc14ab 100644 --- a/plugin/dnstap/dnstapio/io.go +++ b/plugin/dnstap/dnstapio/io.go @@ -20,7 +20,13 @@ const ( queueSize = 10000 ) -type dnstapIO struct { +// Tapper interface is used in testing to mock the Dnstap method. +type Tapper interface { + Dnstap(tap.Dnstap) +} + +// dio implements the Tapper interface. +type dio struct { endpoint string socket bool conn net.Conn @@ -30,9 +36,9 @@ type dnstapIO struct { quit chan struct{} } -// New returns a new and initialized DnstapIO. -func New(endpoint string, socket bool) DnstapIO { - return &dnstapIO{ +// New returns a new and initialized pointer to a dio. +func New(endpoint string, socket bool) *dio { + return &dio{ endpoint: endpoint, socket: socket, enc: newDnstapEncoder(&fs.EncoderOptions{ @@ -44,74 +50,65 @@ func New(endpoint string, socket bool) DnstapIO { } } -// DnstapIO interface -type DnstapIO interface { - Connect() - Dnstap(payload tap.Dnstap) - Close() -} - -func (dio *dnstapIO) newConnect() error { +func (d *dio) newConnect() error { var err error - if dio.socket { - if dio.conn, err = net.Dial("unix", dio.endpoint); err != nil { + if d.socket { + if d.conn, err = net.Dial("unix", d.endpoint); err != nil { return err } } else { - if dio.conn, err = net.DialTimeout("tcp", dio.endpoint, tcpTimeout); err != nil { + if d.conn, err = net.DialTimeout("tcp", d.endpoint, tcpTimeout); err != nil { return err } - if tcpConn, ok := dio.conn.(*net.TCPConn); ok { + if tcpConn, ok := d.conn.(*net.TCPConn); ok { tcpConn.SetWriteBuffer(tcpWriteBufSize) tcpConn.SetNoDelay(false) } } - return dio.enc.resetWriter(dio.conn) + return d.enc.resetWriter(d.conn) } // Connect connects to the dnstap endpoint. -func (dio *dnstapIO) Connect() { - if err := dio.newConnect(); err != nil { +func (d *dio) Connect() { + if err := d.newConnect(); err != nil { log.Error("No connection to dnstap endpoint") } - go dio.serve() + go d.serve() } // Dnstap enqueues the payload for log. -func (dio *dnstapIO) Dnstap(payload tap.Dnstap) { +func (d *dio) Dnstap(payload tap.Dnstap) { select { - case dio.queue <- payload: + case d.queue <- payload: default: - atomic.AddUint32(&dio.dropped, 1) + atomic.AddUint32(&d.dropped, 1) } } -func (dio *dnstapIO) closeConnection() { - dio.enc.close() - if dio.conn != nil { - dio.conn.Close() - dio.conn = nil +func (d *dio) closeConnection() { + d.enc.close() + if d.conn != nil { + d.conn.Close() + d.conn = nil } } // Close waits until the I/O routine is finished to return. -func (dio *dnstapIO) Close() { - close(dio.quit) -} +func (d *dio) Close() { close(d.quit) } -func (dio *dnstapIO) flushBuffer() { - if dio.conn == nil { - if err := dio.newConnect(); err != nil { +func (d *dio) flushBuffer() { + if d.conn == nil { + if err := d.newConnect(); err != nil { return } log.Info("Reconnected to dnstap") } - if err := dio.enc.flushBuffer(); err != nil { + if err := d.enc.flushBuffer(); err != nil { log.Warningf("Connection lost: %s", err) - dio.closeConnection() - if err := dio.newConnect(); err != nil { + d.closeConnection() + if err := d.newConnect(); err != nil { log.Errorf("Cannot connect to dnstap: %s", err) } else { log.Info("Reconnected to dnstap") @@ -119,27 +116,27 @@ func (dio *dnstapIO) flushBuffer() { } } -func (dio *dnstapIO) write(payload *tap.Dnstap) { - if err := dio.enc.writeMsg(payload); err != nil { - atomic.AddUint32(&dio.dropped, 1) +func (d *dio) write(payload *tap.Dnstap) { + if err := d.enc.writeMsg(payload); err != nil { + atomic.AddUint32(&d.dropped, 1) } } -func (dio *dnstapIO) serve() { +func (d *dio) serve() { timeout := time.After(flushTimeout) for { select { - case <-dio.quit: - dio.flushBuffer() - dio.closeConnection() + case <-d.quit: + d.flushBuffer() + d.closeConnection() return - case payload := <-dio.queue: - dio.write(&payload) + case payload := <-d.queue: + d.write(&payload) case <-timeout: - if dropped := atomic.SwapUint32(&dio.dropped, 0); dropped > 0 { + if dropped := atomic.SwapUint32(&d.dropped, 0); dropped > 0 { log.Warningf("Dropped dnstap messages: %d", dropped) } - dio.flushBuffer() + d.flushBuffer() timeout = time.After(flushTimeout) } } diff --git a/plugin/dnstap/dnstapio/io_test.go b/plugin/dnstap/dnstapio/io_test.go index 4716b4fd4..f26f50095 100644 --- a/plugin/dnstap/dnstapio/io_test.go +++ b/plugin/dnstap/dnstapio/io_test.go @@ -26,7 +26,6 @@ func accept(t *testing.T, l net.Listener, count int) { server, err := l.Accept() if err != nil { t.Fatalf("Server accepted: %s", err) - return } dec, err := fs.NewDecoder(server, &fs.DecoderOptions{ @@ -35,7 +34,6 @@ func accept(t *testing.T, l net.Listener, count int) { }) if err != nil { t.Fatalf("Server decoder: %s", err) - return } for i := 0; i < count; i++ { diff --git a/plugin/dnstap/dnstapio/log_test.go b/plugin/dnstap/dnstapio/log_test.go deleted file mode 100644 index c37b3df73..000000000 --- a/plugin/dnstap/dnstapio/log_test.go +++ /dev/null @@ -1,5 +0,0 @@ -package dnstapio - -import clog "github.com/coredns/coredns/plugin/pkg/log" - -func init() { clog.Discard() } diff --git a/plugin/dnstap/gocontext.go b/plugin/dnstap/gocontext.go deleted file mode 100644 index a8cc2c2b4..000000000 --- a/plugin/dnstap/gocontext.go +++ /dev/null @@ -1,23 +0,0 @@ -package dnstap - -import "context" - -type contextKey struct{} - -var dnstapKey = contextKey{} - -// ContextWithTapper returns a new `context.Context` that holds a reference to -// `t`'s Tapper. -func ContextWithTapper(ctx context.Context, t Tapper) context.Context { - return context.WithValue(ctx, dnstapKey, t) -} - -// TapperFromContext returns the `Tapper` previously associated with `ctx`, or -// `nil` if no such `Tapper` could be found. -func TapperFromContext(ctx context.Context) Tapper { - val := ctx.Value(dnstapKey) - if sp, ok := val.(Tapper); ok { - return sp - } - return nil -} diff --git a/plugin/dnstap/handler.go b/plugin/dnstap/handler.go index 0dde3a346..7451d63f0 100644 --- a/plugin/dnstap/handler.go +++ b/plugin/dnstap/handler.go @@ -5,7 +5,7 @@ import ( "time" "github.com/coredns/coredns/plugin" - "github.com/coredns/coredns/plugin/dnstap/taprw" + "github.com/coredns/coredns/plugin/dnstap/dnstapio" tap "github.com/dnstap/golang-dnstap" "github.com/miekg/dns" @@ -14,75 +14,29 @@ import ( // Dnstap is the dnstap handler. type Dnstap struct { Next plugin.Handler - IO IORoutine + io dnstapio.Tapper - // Set to true to include the relevant raw DNS message into the dnstap messages. - JoinRawMessage bool + // IncludeRawMessage will include the raw DNS message into the dnstap messages if true. + IncludeRawMessage bool } -type ( - // IORoutine is the dnstap I/O thread as defined by: . - IORoutine interface { - Dnstap(tap.Dnstap) - } - // Tapper is implemented by the Context passed by the dnstap handler. - Tapper interface { - TapMessage(message *tap.Message) - Pack() bool - } -) - -// ContextKey defines the type of key that is used to save data into the context. -type ContextKey string - -const ( - // DnstapSendOption specifies the Dnstap message to be send. Default is sent all. - DnstapSendOption ContextKey = "dnstap-send-option" -) - -// TapMessage implements Tapper. +// TapMessage sends the message m to the dnstap interface. func (h Dnstap) TapMessage(m *tap.Message) { t := tap.Dnstap_MESSAGE - h.IO.Dnstap(tap.Dnstap{ - Type: &t, - Message: m, - }) -} - -// Pack returns true if the raw DNS message should be included into the dnstap messages. -func (h Dnstap) Pack() bool { - return h.JoinRawMessage + h.io.Dnstap(tap.Dnstap{Type: &t, Message: m}) } // ServeDNS logs the client query and response to dnstap and passes the dnstap Context. func (h Dnstap) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { - - // Add send option into context so other plugin can decide on which DNSTap - // message to be sent out - sendOption := taprw.SendOption{Cq: true, Cr: true} - newCtx := context.WithValue(ctx, DnstapSendOption, &sendOption) - newCtx = ContextWithTapper(newCtx, h) - - rw := &taprw.ResponseWriter{ + rw := &ResponseWriter{ ResponseWriter: w, - Tapper: &h, + Dnstap: h, Query: r, - Send: &sendOption, - QueryEpoch: time.Now(), + QueryTime: time.Now(), } - code, err := plugin.NextOrFailure(h.Name(), h.Next, newCtx, rw, r) - if err != nil { - // ignore dnstap errors - return code, err - } - - if err = rw.DnstapError(); err != nil { - return code, plugin.Error("dnstap", err) - } - - return code, nil + return plugin.NextOrFailure(h.Name(), h.Next, ctx, rw, r) } -// Name returns dnstap. +// Name implements the plugin.Plugin interface. func (h Dnstap) Name() string { return "dnstap" } diff --git a/plugin/dnstap/handler_test.go b/plugin/dnstap/handler_test.go index b86fe019d..acfbc8770 100644 --- a/plugin/dnstap/handler_test.go +++ b/plugin/dnstap/handler_test.go @@ -2,14 +2,11 @@ package dnstap import ( "context" - "errors" "net" - "strings" "testing" - "github.com/coredns/coredns/plugin" - "github.com/coredns/coredns/plugin/dnstap/test" - mwtest "github.com/coredns/coredns/plugin/test" + "github.com/coredns/coredns/plugin/dnstap/msg" + test "github.com/coredns/coredns/plugin/test" tap "github.com/dnstap/golang-dnstap" "github.com/miekg/dns" ) @@ -18,15 +15,14 @@ func testCase(t *testing.T, tapq, tapr *tap.Message, q, r *dns.Msg) { w := writer{t: t} w.queue = append(w.queue, tapq, tapr) h := Dnstap{ - Next: mwtest.HandlerFunc(func(_ context.Context, + Next: test.HandlerFunc(func(_ context.Context, w dns.ResponseWriter, _ *dns.Msg) (int, error) { return 0, w.WriteMsg(r) }), - IO: &w, - JoinRawMessage: false, + io: &w, } - _, err := h.ServeDNS(context.TODO(), &mwtest.ResponseWriter{}, q) + _, err := h.ServeDNS(context.TODO(), &test.ResponseWriter{}, q) if err != nil { t.Fatal(err) } @@ -39,78 +35,50 @@ type writer struct { func (w *writer) Dnstap(e tap.Dnstap) { if len(w.queue) == 0 { - w.t.Error("Message not expected.") + w.t.Error("Message not expected") } - if !test.MsgEqual(w.queue[0], e.Message) { - w.t.Errorf("Want: %v, have: %v", w.queue[0], e.Message) + + ex := w.queue[0] + got := e.Message + + if string(ex.QueryAddress) != string(got.QueryAddress) { + w.t.Errorf("Expected source adress %s, got %s", ex.QueryAddress, got.QueryAddress) + } + if string(ex.ResponseAddress) != string(got.ResponseAddress) { + w.t.Errorf("Expected response adress %s, got %s", ex.ResponseAddress, got.ResponseAddress) + } + if *ex.QueryPort != *got.QueryPort { + w.t.Errorf("Expected port %d, got %d", *ex.QueryPort, *got.QueryPort) + } + if *ex.SocketFamily != *got.SocketFamily { + w.t.Errorf("Expected socket family %d, got %d", *ex.SocketFamily, *got.SocketFamily) } w.queue = w.queue[1:] } func TestDnstap(t *testing.T) { - q := mwtest.Case{Qname: "example.org", Qtype: dns.TypeA}.Msg() - r := mwtest.Case{ + q := test.Case{Qname: "example.org", Qtype: dns.TypeA}.Msg() + r := test.Case{ Qname: "example.org.", Qtype: dns.TypeA, Answer: []dns.RR{ - mwtest.A("example.org. 3600 IN A 10.0.0.1"), + test.A("example.org. 3600 IN A 10.0.0.1"), }, }.Msg() - tapq, _ := test.TestingData().ToClientQuery() - tapr, _ := test.TestingData().ToClientResponse() + tapq := testMessage() // leave type unset for deepEqual + msg.SetType(tapq, tap.Message_CLIENT_QUERY) + tapr := testMessage() + msg.SetType(tapr, tap.Message_CLIENT_RESPONSE) testCase(t, tapq, tapr, q, r) } -type noWriter struct { -} - -func (n noWriter) Dnstap(d tap.Dnstap) { -} - -func endWith(c int, err error) plugin.Handler { - return mwtest.HandlerFunc(func(_ context.Context, w dns.ResponseWriter, _ *dns.Msg) (int, error) { - w.WriteMsg(nil) // trigger plugin dnstap to log client query and response - // maybe dnstap should log the client query when no message is written... - return c, err - }) -} - -type badAddr struct { -} - -func (bad badAddr) Network() string { - return "bad network" -} -func (bad badAddr) String() string { - return "bad address" -} - -type badRW struct { - dns.ResponseWriter -} - -func (bad *badRW) RemoteAddr() net.Addr { - return badAddr{} -} - -func TestError(t *testing.T) { - h := Dnstap{ - Next: endWith(0, nil), - IO: noWriter{}, - JoinRawMessage: false, - } - rw := &badRW{&mwtest.ResponseWriter{}} - - // the dnstap error will show only if there is no plugin error - _, err := h.ServeDNS(context.TODO(), rw, nil) - if err == nil || !strings.HasPrefix(err.Error(), "plugin/dnstap") { - t.Fatal("Must return the dnstap error but have:", err) - } - - // plugin errors will always overwrite dnstap errors - pluginErr := errors.New("plugin error") - h.Next = endWith(0, pluginErr) - _, err = h.ServeDNS(context.TODO(), rw, nil) - if err != pluginErr { - t.Fatal("Must return the plugin error but have:", err) +func testMessage() *tap.Message { + inet := tap.SocketFamily_INET + udp := tap.SocketProtocol_UDP + port := uint32(40212) + return &tap.Message{ + SocketFamily: &inet, + SocketProtocol: &udp, + QueryAddress: net.ParseIP("10.240.0.1"), + QueryPort: &port, } } diff --git a/plugin/dnstap/msg/msg.go b/plugin/dnstap/msg/msg.go index d96fc6c9a..f9d84c45a 100644 --- a/plugin/dnstap/msg/msg.go +++ b/plugin/dnstap/msg/msg.go @@ -1,159 +1,97 @@ package msg import ( - "errors" + "fmt" "net" - "strconv" "time" tap "github.com/dnstap/golang-dnstap" - "github.com/miekg/dns" ) -// Builder helps to build a Dnstap message. -type Builder struct { - Packed []byte - SocketProto tap.SocketProtocol - SocketFam tap.SocketFamily - Address net.IP - Port uint32 - TimeSec uint64 - TimeNsec uint32 +var ( + protoUDP = tap.SocketProtocol_UDP + protoTCP = tap.SocketProtocol_TCP + familyINET = tap.SocketFamily_INET + familyINET6 = tap.SocketFamily_INET6 +) - err error -} - -// New returns a new Builder -func New() *Builder { - return &Builder{} -} - -// Addr adds the remote address to the message. -func (b *Builder) Addr(remote net.Addr) *Builder { - if b.err != nil { - return b - } - - switch addr := remote.(type) { +// SetQueryAddress adds the query address to the message. This also sets the SocketFamily and SocketProtocol. +func SetQueryAddress(t *tap.Message, addr net.Addr) error { + t.SocketFamily = &familyINET + switch a := addr.(type) { case *net.TCPAddr: - b.Address = addr.IP - b.Port = uint32(addr.Port) - b.SocketProto = tap.SocketProtocol_TCP - case *net.UDPAddr: - b.Address = addr.IP - b.Port = uint32(addr.Port) - b.SocketProto = tap.SocketProtocol_UDP - default: - b.err = errors.New("unknown remote address type") - return b - } + t.SocketProtocol = &protoTCP + t.QueryAddress = a.IP - if b.Address.To4() != nil { - b.SocketFam = tap.SocketFamily_INET - } else { - b.SocketFam = tap.SocketFamily_INET6 - } - return b -} + p := uint32(a.Port) + t.QueryPort = &p -// Msg adds the raw DNS message to the dnstap message. -func (b *Builder) Msg(m *dns.Msg) *Builder { - if b.err != nil { - return b - } - - b.Packed, b.err = m.Pack() - return b -} - -// HostPort adds the remote address as encoded by dnsutil.ParseHostPortOrFile to the message. -func (b *Builder) HostPort(addr string) *Builder { - ip, port, err := net.SplitHostPort(addr) - if err != nil { - b.err = err - return b - } - p, err := strconv.ParseUint(port, 10, 32) - if err != nil { - b.err = err - return b - } - b.Port = uint32(p) - - if ip := net.ParseIP(ip); ip != nil { - b.Address = []byte(ip) - if ip := ip.To4(); ip != nil { - b.SocketFam = tap.SocketFamily_INET - } else { - b.SocketFam = tap.SocketFamily_INET6 + if a.IP.To4() == nil { + t.SocketFamily = &familyINET6 } - return b + return nil + case *net.UDPAddr: + t.SocketProtocol = &protoUDP + t.QueryAddress = a.IP + + p := uint32(a.Port) + t.QueryPort = &p + + if a.IP.To4() == nil { + t.SocketFamily = &familyINET6 + } + return nil + default: + return fmt.Errorf("unknown address type: %T", a) } - b.err = errors.New("not an ip address") - return b } -// Time adds the timestamp to the message. -func (b *Builder) Time(ts time.Time) *Builder { - b.TimeSec = uint64(ts.Unix()) - b.TimeNsec = uint32(ts.Nanosecond()) - return b +// SetResponseAddress the response address to the message. This also sets the SocketFamily and SocketProtocol. +func SetResponseAddress(t *tap.Message, addr net.Addr) error { + t.SocketFamily = &familyINET + switch a := addr.(type) { + case *net.TCPAddr: + t.SocketProtocol = &protoTCP + t.ResponseAddress = a.IP + + p := uint32(a.Port) + t.ResponsePort = &p + + if a.IP.To4() == nil { + t.SocketFamily = &familyINET6 + } + return nil + case *net.UDPAddr: + t.SocketProtocol = &protoUDP + t.ResponseAddress = a.IP + + p := uint32(a.Port) + t.ResponsePort = &p + + if a.IP.To4() == nil { + t.SocketFamily = &familyINET6 + } + return nil + default: + return fmt.Errorf("unknown address type: %T", a) + } } -// ToClientResponse transforms Data into a client response message. -func (b *Builder) ToClientResponse() (*tap.Message, error) { - t := tap.Message_CLIENT_RESPONSE - return &tap.Message{ - Type: &t, - SocketFamily: &b.SocketFam, - SocketProtocol: &b.SocketProto, - ResponseTimeSec: &b.TimeSec, - ResponseTimeNsec: &b.TimeNsec, - ResponseMessage: b.Packed, - QueryAddress: b.Address, - QueryPort: &b.Port, - }, b.err +// SetQueryTime sets the time of the query in t. +func SetQueryTime(t *tap.Message, ti time.Time) { + qts := uint64(ti.Unix()) + qtn := uint32(ti.Nanosecond()) + t.QueryTimeSec = &qts + t.QueryTimeNsec = &qtn } -// ToClientQuery transforms Data into a client query message. -func (b *Builder) ToClientQuery() (*tap.Message, error) { - t := tap.Message_CLIENT_QUERY - return &tap.Message{ - Type: &t, - SocketFamily: &b.SocketFam, - SocketProtocol: &b.SocketProto, - QueryTimeSec: &b.TimeSec, - QueryTimeNsec: &b.TimeNsec, - QueryMessage: b.Packed, - QueryAddress: b.Address, - QueryPort: &b.Port, - }, b.err +// SetResponseTime sets the time of the response in t. +func SetResponseTime(t *tap.Message, ti time.Time) { + rts := uint64(ti.Unix()) + rtn := uint32(ti.Nanosecond()) + t.ResponseTimeSec = &rts + t.ResponseTimeNsec = &rtn } -// ToOutsideQuery transforms the data into a forwarder or resolver query message. -func (b *Builder) ToOutsideQuery(t tap.Message_Type) (*tap.Message, error) { - return &tap.Message{ - Type: &t, - SocketFamily: &b.SocketFam, - SocketProtocol: &b.SocketProto, - QueryTimeSec: &b.TimeSec, - QueryTimeNsec: &b.TimeNsec, - QueryMessage: b.Packed, - ResponseAddress: b.Address, - ResponsePort: &b.Port, - }, b.err -} - -// ToOutsideResponse transforms the data into a forwarder or resolver response message. -func (b *Builder) ToOutsideResponse(t tap.Message_Type) (*tap.Message, error) { - return &tap.Message{ - Type: &t, - SocketFamily: &b.SocketFam, - SocketProtocol: &b.SocketProto, - ResponseTimeSec: &b.TimeSec, - ResponseTimeNsec: &b.TimeNsec, - ResponseMessage: b.Packed, - ResponseAddress: b.Address, - ResponsePort: &b.Port, - }, b.err -} +// SetType sets the type in t. +func SetType(t *tap.Message, typ tap.Message_Type) { t.Type = &typ } diff --git a/plugin/dnstap/msg/msg_test.go b/plugin/dnstap/msg/msg_test.go deleted file mode 100644 index 57a4e4fe0..000000000 --- a/plugin/dnstap/msg/msg_test.go +++ /dev/null @@ -1,39 +0,0 @@ -package msg - -import ( - "net" - "reflect" - "testing" - - "github.com/coredns/coredns/plugin/test" - "github.com/coredns/coredns/request" - - tap "github.com/dnstap/golang-dnstap" - "github.com/miekg/dns" -) - -func testRequest(t *testing.T, expected Builder, r request.Request) { - d := Builder{} - d.Addr(r.W.RemoteAddr()) - if d.SocketProto != expected.SocketProto || - d.SocketFam != expected.SocketFam || - !reflect.DeepEqual(d.Address, expected.Address) || - d.Port != expected.Port { - t.Fatalf("Expected: %v, have: %v", expected, d) - return - } -} -func TestRequest(t *testing.T) { - testRequest(t, Builder{ - SocketProto: tap.SocketProtocol_UDP, - SocketFam: tap.SocketFamily_INET, - Address: net.ParseIP("10.240.0.1"), - Port: 40212, - }, testingRequest()) -} -func testingRequest() request.Request { - m := new(dns.Msg) - m.SetQuestion("example.com.", dns.TypeA) - m.SetEdns0(4097, true) - return request.Request{W: &test.ResponseWriter{}, Req: m} -} diff --git a/plugin/dnstap/setup.go b/plugin/dnstap/setup.go index e1589798b..ab5488686 100644 --- a/plugin/dnstap/setup.go +++ b/plugin/dnstap/setup.go @@ -10,14 +10,7 @@ import ( "github.com/coredns/coredns/plugin/pkg/parse" ) -func init() { plugin.Register("dnstap", wrapSetup) } - -func wrapSetup(c *caddy.Controller) error { - if err := setup(c); err != nil { - return plugin.Error("dnstap", err) - } - return nil -} +func init() { plugin.Register("dnstap", setup) } type config struct { target string @@ -53,11 +46,11 @@ func parseConfig(d *caddy.Controller) (c config, err error) { func setup(c *caddy.Controller) error { conf, err := parseConfig(c) if err != nil { - return err + return plugin.Error("dnstap", err) } dio := dnstapio.New(conf.target, conf.socket) - dnstap := Dnstap{IO: dio, JoinRawMessage: conf.full} + dnstap := Dnstap{io: dio, IncludeRawMessage: conf.full} c.OnStartup(func() error { dio.Connect() diff --git a/plugin/dnstap/taprw/writer.go b/plugin/dnstap/taprw/writer.go deleted file mode 100644 index 05cf095b9..000000000 --- a/plugin/dnstap/taprw/writer.go +++ /dev/null @@ -1,79 +0,0 @@ -// Package taprw takes a query and intercepts the response. -// It will log both after the response is written. -package taprw - -import ( - "fmt" - "time" - - "github.com/coredns/coredns/plugin/dnstap/msg" - - tap "github.com/dnstap/golang-dnstap" - "github.com/miekg/dns" -) - -// SendOption stores the flag to indicate whether a certain DNSTap message to -// be sent out or not. -type SendOption struct { - Cq bool - Cr bool -} - -// Tapper is what ResponseWriter needs to log to dnstap. -type Tapper interface { - TapMessage(*tap.Message) - Pack() bool -} - -// ResponseWriter captures the client response and logs the query to dnstap. -// Single request use. -// SendOption configures Dnstap to selectively send Dnstap messages. Default is send all. -type ResponseWriter struct { - QueryEpoch time.Time - Query *dns.Msg - dns.ResponseWriter - Tapper - Send *SendOption - - dnstapErr error -} - -// DnstapError checks if a dnstap error occurred during Write and returns it. -func (w *ResponseWriter) DnstapError() error { - return w.dnstapErr -} - -// WriteMsg writes back the response to the client and THEN works on logging the request -// and response to dnstap. -func (w *ResponseWriter) WriteMsg(resp *dns.Msg) (writeErr error) { - writeErr = w.ResponseWriter.WriteMsg(resp) - writeEpoch := time.Now() - - b := msg.New().Time(w.QueryEpoch).Addr(w.RemoteAddr()) - - if w.Send == nil || w.Send.Cq { - if w.Pack() { - b.Msg(w.Query) - } - if m, err := b.ToClientQuery(); err != nil { - w.dnstapErr = fmt.Errorf("client query: %s", err) - } else { - w.TapMessage(m) - } - } - - if w.Send == nil || w.Send.Cr { - if writeErr == nil { - if w.Pack() { - b.Msg(resp) - } - if m, err := b.Time(writeEpoch).ToClientResponse(); err != nil { - w.dnstapErr = fmt.Errorf("client response: %s", err) - } else { - w.TapMessage(m) - } - } - } - - return writeErr -} diff --git a/plugin/dnstap/taprw/writer_test.go b/plugin/dnstap/taprw/writer_test.go deleted file mode 100644 index d55943894..000000000 --- a/plugin/dnstap/taprw/writer_test.go +++ /dev/null @@ -1,115 +0,0 @@ -package taprw - -import ( - "testing" - - "github.com/coredns/coredns/plugin/dnstap/test" - mwtest "github.com/coredns/coredns/plugin/test" - - "github.com/miekg/dns" -) - -func testingMsg() (m *dns.Msg) { - m = new(dns.Msg) - m.SetQuestion("example.com.", dns.TypeA) - m.SetEdns0(4097, true) - return -} - -func TestClientQueryResponse(t *testing.T) { - trapper := test.TrapTapper{Full: true} - m := testingMsg() - rw := ResponseWriter{ - Query: m, - Tapper: &trapper, - ResponseWriter: &mwtest.ResponseWriter{}, - } - d := test.TestingData() - - // will the wire-format msg be reported? - bin, err := m.Pack() - if err != nil { - t.Fatal(err) - return - } - d.Packed = bin - - if err := rw.WriteMsg(m); err != nil { - t.Fatal(err) - return - } - if l := len(trapper.Trap); l != 2 { - t.Fatalf("Mmsg %d trapped", l) - return - } - want, err := d.ToClientQuery() - if err != nil { - t.Fatal("Testing data must build", err) - } - have := trapper.Trap[0] - if !test.MsgEqual(want, have) { - t.Fatalf("Query: want: %v\nhave: %v", want, have) - } - want, err = d.ToClientResponse() - if err != nil { - t.Fatal("Testing data must build", err) - } - have = trapper.Trap[1] - if !test.MsgEqual(want, have) { - t.Fatalf("Response: want: %v\nhave: %v", want, have) - } -} - -func TestClientQueryResponseWithSendOption(t *testing.T) { - trapper := test.TrapTapper{Full: true} - m := testingMsg() - rw := ResponseWriter{ - Query: m, - Tapper: &trapper, - ResponseWriter: &mwtest.ResponseWriter{}, - } - d := test.TestingData() - bin, err := m.Pack() - if err != nil { - t.Fatal(err) - return - } - d.Packed = bin - - // Do not send both CQ and CR - o := SendOption{Cq: false, Cr: false} - rw.Send = &o - - if err := rw.WriteMsg(m); err != nil { - t.Fatal(err) - return - } - if l := len(trapper.Trap); l != 0 { - t.Fatalf("%d msg trapped", l) - return - } - - //Send CQ - o.Cq = true - if err := rw.WriteMsg(m); err != nil { - t.Fatal(err) - return - } - if l := len(trapper.Trap); l != 1 { - t.Fatalf("%d msg trapped", l) - return - } - - //Send CR - trapper.Trap = trapper.Trap[:0] - o.Cq = false - o.Cr = true - if err := rw.WriteMsg(m); err != nil { - t.Fatal(err) - return - } - if l := len(trapper.Trap); l != 1 { - t.Fatalf("%d msg trapped", l) - return - } -} diff --git a/plugin/dnstap/test/helpers.go b/plugin/dnstap/test/helpers.go deleted file mode 100644 index 5f498d59f..000000000 --- a/plugin/dnstap/test/helpers.go +++ /dev/null @@ -1,72 +0,0 @@ -package test - -import ( - "net" - "reflect" - - "github.com/coredns/coredns/plugin/dnstap/msg" - - tap "github.com/dnstap/golang-dnstap" -) - -// TestingData returns the Data matching coredns/test.ResponseWriter. -func TestingData() (d *msg.Builder) { - d = &msg.Builder{ - SocketFam: tap.SocketFamily_INET, - SocketProto: tap.SocketProtocol_UDP, - Address: net.ParseIP("10.240.0.1"), - Port: 40212, - } - return -} - -type comp struct { - Type *tap.Message_Type - SF *tap.SocketFamily - SP *tap.SocketProtocol - QA []byte - RA []byte - QP *uint32 - RP *uint32 - QTSec bool - RTSec bool - RM []byte - QM []byte -} - -func toComp(m *tap.Message) comp { - return comp{ - Type: m.Type, - SF: m.SocketFamily, - SP: m.SocketProtocol, - QA: m.QueryAddress, - RA: m.ResponseAddress, - QP: m.QueryPort, - RP: m.ResponsePort, - QTSec: m.QueryTimeSec != nil, - RTSec: m.ResponseTimeSec != nil, - RM: m.ResponseMessage, - QM: m.QueryMessage, - } -} - -// MsgEqual compares two dnstap messages ignoring timestamps. -func MsgEqual(a, b *tap.Message) bool { - return reflect.DeepEqual(toComp(a), toComp(b)) -} - -// TrapTapper traps messages. -type TrapTapper struct { - Trap []*tap.Message - Full bool -} - -// Pack returns field Full. -func (t *TrapTapper) Pack() bool { - return t.Full -} - -// TapMessage adds the message to the trap. -func (t *TrapTapper) TapMessage(m *tap.Message) { - t.Trap = append(t.Trap, m) -} diff --git a/plugin/dnstap/writer.go b/plugin/dnstap/writer.go new file mode 100644 index 000000000..315a3a790 --- /dev/null +++ b/plugin/dnstap/writer.go @@ -0,0 +1,53 @@ +package dnstap + +import ( + "time" + + "github.com/coredns/coredns/plugin/dnstap/msg" + tap "github.com/dnstap/golang-dnstap" + "github.com/miekg/dns" +) + +// ResponseWriter captures the client response and logs the query to dnstap. +// Single request use. +type ResponseWriter struct { + QueryTime time.Time + Query *dns.Msg + dns.ResponseWriter + Dnstap +} + +// WriteMsg writes back the response to the client and THEN works on logging the request +// and response to dnstap. +func (w *ResponseWriter) WriteMsg(resp *dns.Msg) error { + err := w.ResponseWriter.WriteMsg(resp) + + q := new(tap.Message) + msg.SetQueryTime(q, w.QueryTime) + msg.SetQueryAddress(q, w.RemoteAddr()) + + if w.IncludeRawMessage { + buf, _ := w.Query.Pack() + q.QueryMessage = buf + } + msg.SetType(q, tap.Message_CLIENT_QUERY) + w.TapMessage(q) + + if err != nil { + return err + } + + r := new(tap.Message) + msg.SetQueryTime(r, w.QueryTime) + msg.SetResponseTime(r, time.Now()) + msg.SetQueryAddress(r, w.RemoteAddr()) + + if w.IncludeRawMessage { + buf, _ := resp.Pack() + r.ResponseMessage = buf + } + + msg.SetType(r, tap.Message_CLIENT_RESPONSE) + w.TapMessage(r) + return nil +} diff --git a/plugin/forward/dnstap.go b/plugin/forward/dnstap.go index 7866aa39b..e005cc02a 100644 --- a/plugin/forward/dnstap.go +++ b/plugin/forward/dnstap.go @@ -1,10 +1,10 @@ package forward import ( - "context" + "net" + "strconv" "time" - "github.com/coredns/coredns/plugin/dnstap" "github.com/coredns/coredns/plugin/dnstap/msg" "github.com/coredns/coredns/request" @@ -12,50 +12,48 @@ import ( "github.com/miekg/dns" ) -func toDnstap(ctx context.Context, host string, f *Forward, state request.Request, reply *dns.Msg, start time.Time) error { - tapper := dnstap.TapperFromContext(ctx) - if tapper == nil { - return nil - } +// toDnstap will send the forward and received message to the dnstap plugin. +func toDnstap(f *Forward, host string, state request.Request, opts options, reply *dns.Msg, start time.Time) { // Query - b := msg.New().Time(start).HostPort(host) - opts := f.opts - t := "" + q := new(tap.Message) + msg.SetQueryTime(q, start) + h, p, _ := net.SplitHostPort(host) // this is preparsed and can't err here + port, _ := strconv.ParseUint(p, 10, 32) // same here + ip := net.ParseIP(h) + + var ta net.Addr = &net.UDPAddr{IP: ip, Port: int(port)} + t := state.Proto() switch { - case opts.forceTCP: // TCP flag has precedence over UDP flag + case opts.forceTCP: t = "tcp" case opts.preferUDP: t = "udp" - default: - t = state.Proto() } if t == "tcp" { - b.SocketProto = tap.SocketProtocol_TCP - } else { - b.SocketProto = tap.SocketProtocol_UDP + ta = &net.TCPAddr{IP: ip, Port: int(port)} } - if tapper.Pack() { - b.Msg(state.Req) + msg.SetQueryAddress(q, ta) + + if f.tapPlugin.IncludeRawMessage { + buf, _ := state.Req.Pack() + q.QueryMessage = buf } - m, err := b.ToOutsideQuery(tap.Message_FORWARDER_QUERY) - if err != nil { - return err - } - tapper.TapMessage(m) + msg.SetType(q, tap.Message_FORWARDER_QUERY) + f.tapPlugin.TapMessage(q) // Response if reply != nil { - if tapper.Pack() { - b.Msg(reply) + r := new(tap.Message) + if f.tapPlugin.IncludeRawMessage { + buf, _ := reply.Pack() + r.ResponseMessage = buf } - m, err := b.Time(time.Now()).ToOutsideResponse(tap.Message_FORWARDER_RESPONSE) - if err != nil { - return err - } - tapper.TapMessage(m) + msg.SetQueryTime(r, start) + msg.SetQueryAddress(r, ta) + msg.SetResponseTime(r, time.Now()) + msg.SetType(r, tap.Message_FORWARDER_RESPONSE) + f.tapPlugin.TapMessage(r) } - - return nil } diff --git a/plugin/forward/dnstap_test.go b/plugin/forward/dnstap_test.go deleted file mode 100644 index c86ee8c75..000000000 --- a/plugin/forward/dnstap_test.go +++ /dev/null @@ -1,63 +0,0 @@ -package forward - -import ( - "context" - "testing" - "time" - - "github.com/coredns/coredns/plugin/dnstap" - "github.com/coredns/coredns/plugin/dnstap/msg" - "github.com/coredns/coredns/plugin/dnstap/test" - mwtest "github.com/coredns/coredns/plugin/test" - "github.com/coredns/coredns/request" - - tap "github.com/dnstap/golang-dnstap" - "github.com/miekg/dns" -) - -func testCase(t *testing.T, f *Forward, q, r *dns.Msg, datq, datr *msg.Builder) { - tapq, _ := datq.ToOutsideQuery(tap.Message_FORWARDER_QUERY) - tapr, _ := datr.ToOutsideResponse(tap.Message_FORWARDER_RESPONSE) - tapper := test.TrapTapper{} - ctx := dnstap.ContextWithTapper(context.TODO(), &tapper) - err := toDnstap(ctx, "10.240.0.1:40212", f, - request.Request{W: &mwtest.ResponseWriter{}, Req: q}, r, time.Now()) - if err != nil { - t.Fatal(err) - } - if len(tapper.Trap) != 2 { - t.Fatalf("Messages: %d", len(tapper.Trap)) - } - if !test.MsgEqual(tapper.Trap[0], tapq) { - t.Errorf("Want: %v\nhave: %v", tapq, tapper.Trap[0]) - } - if !test.MsgEqual(tapper.Trap[1], tapr) { - t.Errorf("Want: %v\nhave: %v", tapr, tapper.Trap[1]) - } -} - -func TestDnstap(t *testing.T) { - q := mwtest.Case{Qname: "example.org", Qtype: dns.TypeA}.Msg() - r := mwtest.Case{ - Qname: "example.org.", Qtype: dns.TypeA, - Answer: []dns.RR{ - mwtest.A("example.org. 3600 IN A 10.0.0.1"), - }, - }.Msg() - tapq, tapr := test.TestingData(), test.TestingData() - fu := New() - fu.opts.preferUDP = true - testCase(t, fu, q, r, tapq, tapr) - tapq.SocketProto = tap.SocketProtocol_TCP - tapr.SocketProto = tap.SocketProtocol_TCP - ft := New() - ft.opts.forceTCP = true - testCase(t, ft, q, r, tapq, tapr) -} - -func TestNoDnstap(t *testing.T) { - err := toDnstap(context.TODO(), "", nil, request.Request{}, nil, time.Now()) - if err != nil { - t.Fatal(err) - } -} diff --git a/plugin/forward/forward.go b/plugin/forward/forward.go index 673d04732..e9a180cb6 100644 --- a/plugin/forward/forward.go +++ b/plugin/forward/forward.go @@ -13,6 +13,7 @@ import ( "github.com/coredns/coredns/plugin" "github.com/coredns/coredns/plugin/debug" + "github.com/coredns/coredns/plugin/dnstap" clog "github.com/coredns/coredns/plugin/pkg/log" "github.com/coredns/coredns/request" @@ -46,6 +47,8 @@ type Forward struct { // the maximum allowed (maxConcurrent) ErrLimitExceeded error + tapPlugin *dnstap.Dnstap // when the dnstap plugin is loaded, we use to this to send messages out. + Next plugin.Handler } @@ -140,7 +143,10 @@ func (f *Forward) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg if child != nil { child.Finish() } - taperr := toDnstap(ctx, proxy.addr, f, state, ret, start) + + if f.tapPlugin != nil { + toDnstap(f, proxy.addr, state, opts, ret, start) + } upstreamErr = err @@ -163,11 +169,11 @@ func (f *Forward) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg formerr := new(dns.Msg) formerr.SetRcode(state.Req, dns.RcodeFormatError) w.WriteMsg(formerr) - return 0, taperr + return 0, nil } w.WriteMsg(ret) - return 0, taperr + return 0, nil } if upstreamErr != nil { diff --git a/plugin/forward/setup.go b/plugin/forward/setup.go index 125d53e4c..7504e9409 100644 --- a/plugin/forward/setup.go +++ b/plugin/forward/setup.go @@ -10,6 +10,7 @@ import ( "github.com/coredns/caddy" "github.com/coredns/coredns/core/dnsserver" "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/dnstap" "github.com/coredns/coredns/plugin/pkg/parse" pkgtls "github.com/coredns/coredns/plugin/pkg/tls" "github.com/coredns/coredns/plugin/pkg/transport" @@ -34,6 +35,14 @@ func setup(c *caddy.Controller) error { c.OnStartup(func() error { return f.OnStartup() }) + c.OnStartup(func() error { + if taph := dnsserver.GetConfig(c).Handler("dnstap"); taph != nil { + if tapPlugin, ok := taph.(dnstap.Dnstap); ok { + f.tapPlugin = &tapPlugin + } + } + return nil + }) c.OnShutdown(func() error { return f.OnShutdown()