diff --git a/middleware/autopath/autopath.go b/middleware/autopath/autopath.go index ec6bb674b..6779efa18 100644 --- a/middleware/autopath/autopath.go +++ b/middleware/autopath/autopath.go @@ -36,6 +36,7 @@ import ( "github.com/coredns/coredns/middleware" "github.com/coredns/coredns/middleware/pkg/dnsutil" + "github.com/coredns/coredns/middleware/pkg/nonwriter" "github.com/coredns/coredns/request" "github.com/miekg/dns" @@ -101,7 +102,7 @@ func (a *AutoPath) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Ms for i, s := range searchpath { newQName := base + "." + s ar.Question[0].Name = newQName - nw := NewNonWriter(w) + nw := nonwriter.New(w) rcode, err := middleware.NextOrFailure(a.Name(), a.Next, ctx, nw, ar) if err != nil { diff --git a/middleware/autopath/nonwriter.go b/middleware/autopath/nonwriter.go deleted file mode 100644 index 0d4c98119..000000000 --- a/middleware/autopath/nonwriter.go +++ /dev/null @@ -1,22 +0,0 @@ -package autopath - -import ( - "github.com/miekg/dns" -) - -// NonWriter is a type of ResponseWriter that captures the message, but never writes to the client. -type NonWriter struct { - dns.ResponseWriter - Msg *dns.Msg -} - -// NewNonWriter makes and returns a new NonWriter. -func NewNonWriter(w dns.ResponseWriter) *NonWriter { return &NonWriter{ResponseWriter: w} } - -// WriteMsg records the message, but doesn't write it itself. -func (r *NonWriter) WriteMsg(res *dns.Msg) error { - r.Msg = res - return nil -} - -func (r *NonWriter) Write(buf []byte) (int, error) { return len(buf), nil } diff --git a/middleware/federation/federation.go b/middleware/federation/federation.go index 16c698ef8..724dcb656 100644 --- a/middleware/federation/federation.go +++ b/middleware/federation/federation.go @@ -17,6 +17,7 @@ import ( "github.com/coredns/coredns/middleware" "github.com/coredns/coredns/middleware/etcd/msg" "github.com/coredns/coredns/middleware/pkg/dnsutil" + "github.com/coredns/coredns/middleware/pkg/nonwriter" "github.com/coredns/coredns/request" "github.com/miekg/dns" @@ -67,7 +68,7 @@ func (f *Federation) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns. // Start the next middleware, but with a nowriter, capture the result, if NXDOMAIN // perform federation, otherwise just write the result. - nw := NewNonWriter(w) + nw := nonwriter.New(w) ret, err := middleware.NextOrFailure(f.Name(), f.Next, ctx, nw, r) if !middleware.ClientWrite(ret) { diff --git a/middleware/federation/nonwriter.go b/middleware/federation/nonwriter.go deleted file mode 100644 index c60fb1075..000000000 --- a/middleware/federation/nonwriter.go +++ /dev/null @@ -1,22 +0,0 @@ -package federation - -import ( - "github.com/miekg/dns" -) - -// NonWriter is a type of ResponseWriter that captures the message, but never writes to the client. -type NonWriter struct { - dns.ResponseWriter - Msg *dns.Msg -} - -// NewNonWriter makes and returns a new NonWriter. -func NewNonWriter(w dns.ResponseWriter) *NonWriter { return &NonWriter{ResponseWriter: w} } - -// WriteMsg records the message, but doesn't write it itself. -func (r *NonWriter) WriteMsg(res *dns.Msg) error { - r.Msg = res - return nil -} - -func (r *NonWriter) Write(buf []byte) (int, error) { return len(buf), nil } diff --git a/middleware/pkg/nonwriter/nonwriter.go b/middleware/pkg/nonwriter/nonwriter.go new file mode 100644 index 000000000..7819a320f --- /dev/null +++ b/middleware/pkg/nonwriter/nonwriter.go @@ -0,0 +1,23 @@ +// Package nonwriter implements a dns.ResponseWriter that never writes, but captures the dns.Msg being written. +package nonwriter + +import ( + "github.com/miekg/dns" +) + +// Writer is a type of ResponseWriter that captures the message, but never writes to the client. +type Writer struct { + dns.ResponseWriter + Msg *dns.Msg +} + +// New makes and returns a new NonWriter. +func New(w dns.ResponseWriter) *Writer { return &Writer{ResponseWriter: w} } + +// WriteMsg records the message, but doesn't write it itself. +func (w *Writer) WriteMsg(res *dns.Msg) error { + w.Msg = res + return nil +} + +func (w *Writer) Write(buf []byte) (int, error) { return len(buf), nil } diff --git a/middleware/pkg/nonwriter/nonwriter_test.go b/middleware/pkg/nonwriter/nonwriter_test.go new file mode 100644 index 000000000..d8433af55 --- /dev/null +++ b/middleware/pkg/nonwriter/nonwriter_test.go @@ -0,0 +1,19 @@ +package nonwriter + +import ( + "testing" + + "github.com/miekg/dns" +) + +func TestNonWriter(t *testing.T) { + nw := New(nil) + m := new(dns.Msg) + m.SetQuestion("example.org.", dns.TypeA) + if err := nw.WriteMsg(m); err != nil { + t.Errorf("Got error when writing to nonwriter: %s", err) + } + if x := nw.Msg.Question[0].Name; x != "example.org." { + t.Errorf("Expacted 'example.org.' got %q:", x) + } +}