dnstest: add multirecorder (#1326)

* dnstest: add multirecorder

This adds a new recorder that captures all messages written to it. This
can be useful when, for instance, testing AXFR which can write muliple
messages back to the client.

* docs
This commit is contained in:
Miek Gieben 2017-12-22 08:54:27 +00:00 committed by GitHub
parent 08076e5284
commit 1f81d154ed
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 80 additions and 0 deletions

View file

@ -0,0 +1,41 @@
package dnstest
import (
"time"
"github.com/miekg/dns"
)
// MultiRecorder is a type of ResponseWriter that captures all messages written to it.
type MultiRecorder struct {
Len int
Msgs []*dns.Msg
Start time.Time
dns.ResponseWriter
}
// NewMultiRecorder makes and returns a new MultiRecorder.
func NewMultiRecorder(w dns.ResponseWriter) *MultiRecorder {
return &MultiRecorder{
ResponseWriter: w,
Msgs: make([]*dns.Msg, 0),
Start: time.Now(),
}
}
// WriteMsg records the message and its length written to it and call the
// underlying ResponseWriter's WriteMsg method.
func (r *MultiRecorder) WriteMsg(res *dns.Msg) error {
r.Len += res.Len()
r.Msgs = append(r.Msgs, res)
return r.ResponseWriter.WriteMsg(res)
}
// Write is a wrapper that records the length of the messages that get written to it.
func (r *MultiRecorder) Write(buf []byte) (int, error) {
n, err := r.ResponseWriter.Write(buf)
if err == nil {
r.Len += n
}
return n, err
}

View file

@ -0,0 +1,39 @@
package dnstest
import (
"testing"
"github.com/miekg/dns"
)
func TestMultiWriteMsg(t *testing.T) {
w := &responseWriter{}
record := NewMultiRecorder(w)
responseTestName := "testmsg.example.org."
responseTestMsg := new(dns.Msg)
responseTestMsg.SetQuestion(responseTestName, dns.TypeA)
record.WriteMsg(responseTestMsg)
record.WriteMsg(responseTestMsg)
if len(record.Msgs) != 2 {
t.Fatalf("Expected 2 messages to be written, but instead found %d\n", len(record.Msgs))
}
if record.Len != responseTestMsg.Len()*2 {
t.Fatalf("Expected the bytes written counter to be %d, but instead found %d\n", responseTestMsg.Len()*2, record.Len)
}
}
func TestMultiWrite(t *testing.T) {
w := &responseWriter{}
record := NewRecorder(w)
responseTest := []byte("testmsg.example.org.")
record.Write(responseTest)
record.Write(responseTest)
if record.Len != len(responseTest)*2 {
t.Fatalf("Expected the bytes written counter to be %d, but instead found %d\n", len(responseTest)*2, record.Len)
}
}