diff --git a/plugin/pkg/dnstest/multirecorder.go b/plugin/pkg/dnstest/multirecorder.go new file mode 100644 index 000000000..fe8ee03ac --- /dev/null +++ b/plugin/pkg/dnstest/multirecorder.go @@ -0,0 +1,41 @@ +package dnstest + +import ( + "time" + + "github.com/miekg/dns" +) + +// MultiRecorder is a type of ResponseWriter that captures all messages written to it. +type MultiRecorder struct { + Len int + Msgs []*dns.Msg + Start time.Time + dns.ResponseWriter +} + +// NewMultiRecorder makes and returns a new MultiRecorder. +func NewMultiRecorder(w dns.ResponseWriter) *MultiRecorder { + return &MultiRecorder{ + ResponseWriter: w, + Msgs: make([]*dns.Msg, 0), + Start: time.Now(), + } +} + +// WriteMsg records the message and its length written to it and call the +// underlying ResponseWriter's WriteMsg method. +func (r *MultiRecorder) WriteMsg(res *dns.Msg) error { + r.Len += res.Len() + r.Msgs = append(r.Msgs, res) + return r.ResponseWriter.WriteMsg(res) +} + +// Write is a wrapper that records the length of the messages that get written to it. +func (r *MultiRecorder) Write(buf []byte) (int, error) { + n, err := r.ResponseWriter.Write(buf) + if err == nil { + r.Len += n + } + return n, err +} diff --git a/plugin/pkg/dnstest/multirecorder_test.go b/plugin/pkg/dnstest/multirecorder_test.go new file mode 100644 index 000000000..756b635ac --- /dev/null +++ b/plugin/pkg/dnstest/multirecorder_test.go @@ -0,0 +1,39 @@ +package dnstest + +import ( + "testing" + + "github.com/miekg/dns" +) + +func TestMultiWriteMsg(t *testing.T) { + w := &responseWriter{} + record := NewMultiRecorder(w) + + responseTestName := "testmsg.example.org." + responseTestMsg := new(dns.Msg) + responseTestMsg.SetQuestion(responseTestName, dns.TypeA) + + record.WriteMsg(responseTestMsg) + record.WriteMsg(responseTestMsg) + + if len(record.Msgs) != 2 { + t.Fatalf("Expected 2 messages to be written, but instead found %d\n", len(record.Msgs)) + + } + if record.Len != responseTestMsg.Len()*2 { + t.Fatalf("Expected the bytes written counter to be %d, but instead found %d\n", responseTestMsg.Len()*2, record.Len) + } +} + +func TestMultiWrite(t *testing.T) { + w := &responseWriter{} + record := NewRecorder(w) + responseTest := []byte("testmsg.example.org.") + + record.Write(responseTest) + record.Write(responseTest) + if record.Len != len(responseTest)*2 { + t.Fatalf("Expected the bytes written counter to be %d, but instead found %d\n", len(responseTest)*2, record.Len) + } +}