diff --git a/plugin/erratic/erratic.go b/plugin/erratic/erratic.go index 3460f3bca..f60e605d1 100644 --- a/plugin/erratic/erratic.go +++ b/plugin/erratic/erratic.go @@ -19,6 +19,7 @@ type Erratic struct { duration time.Duration truncate uint64 + large bool // undocumented feature; return large responses for A request (>512B, to test compression). q uint64 // counter of queries } @@ -57,6 +58,11 @@ func (e *Erratic) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg rr := *(rrA.(*dns.A)) rr.Header().Name = state.QName() m.Answer = append(m.Answer, &rr) + if e.large { + for i := 0; i < 29; i++ { + m.Answer = append(m.Answer, &rr) + } + } case dns.TypeAAAA: rr := *(rrAAAA.(*dns.AAAA)) rr.Header().Name = state.QName() diff --git a/plugin/erratic/erratic_test.go b/plugin/erratic/erratic_test.go index 406fd8774..ec2ec5c0a 100644 --- a/plugin/erratic/erratic_test.go +++ b/plugin/erratic/erratic_test.go @@ -98,3 +98,19 @@ func TestAxfr(t *testing.T) { t.Errorf("Expected for record to be %d, got %d", dns.TypeSOA, x) } } + +func TestErratic(t *testing.T) { + e := &Erratic{drop: 0, delay: 0} + + ctx := context.TODO() + + req := new(dns.Msg) + req.SetQuestion("example.org.", dns.TypeA) + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + e.ServeDNS(ctx, rec, req) + + if rec.Msg.Answer[0].Header().Rrtype != dns.TypeA { + t.Errorf("Expected A response, got %d type", rec.Msg.Answer[0].Header().Rrtype) + } +} diff --git a/plugin/erratic/setup.go b/plugin/erratic/setup.go index 52c4d245c..79e4449ee 100644 --- a/plugin/erratic/setup.go +++ b/plugin/erratic/setup.go @@ -104,6 +104,8 @@ func parseErratic(c *caddy.Controller) (*Erratic, error) { return nil, fmt.Errorf("illegal amount value given %q", args[0]) } e.truncate = uint64(amount) + case "large": + e.large = true default: return nil, c.Errf("unknown property '%s'", c.Val()) } diff --git a/request/request.go b/request/request.go index c4e4eea3c..f560612c8 100644 --- a/request/request.go +++ b/request/request.go @@ -226,11 +226,7 @@ func (r *Request) SizeAndDo(m *dns.Msg) bool { return true } -// Scrub is a noop function, added for backwards compatibility reasons. The original Scrub is now called -// automatically by the server on writing the reply. See ScrubWriter. -func (r *Request) Scrub(reply *dns.Msg) (*dns.Msg, int) { return reply, 0 } - -// scrub scrubs the reply message so that it will fit the client's buffer. It will first +// Scrub scrubs the reply message so that it will fit the client's buffer. It will first // check if the reply fits without compression and then *with* compression. // Scrub will then use binary search to find a save cut off point in the additional section. // If even *without* the additional section the reply still doesn't fit we @@ -238,7 +234,7 @@ func (r *Request) Scrub(reply *dns.Msg) (*dns.Msg, int) { return reply, 0 } // we set the TC bit on the reply; indicating the client should retry over TCP. // Note, the TC bit will be set regardless of protocol, even TCP message will // get the bit, the client should then retry with pigeons. -func (r *Request) scrub(reply *dns.Msg) *dns.Msg { +func (r *Request) Scrub(reply *dns.Msg) *dns.Msg { size := r.Size() reply.Compress = false diff --git a/request/request_test.go b/request/request_test.go index 6685ad3b3..bfc95bc5e 100644 --- a/request/request_test.go +++ b/request/request_test.go @@ -73,7 +73,7 @@ func TestRequestScrubAnswer(t *testing.T) { fmt.Sprintf("large.example.com. 10 IN SRV 0 0 80 10-0-0-%d.default.pod.k8s.example.com.", i))) } - req.scrub(reply) + req.Scrub(reply) if want, got := req.Size(), reply.Len(); want < got { t.Errorf("Want scrub to reduce message length below %d bytes, got %d bytes", want, got) } @@ -94,7 +94,7 @@ func TestRequestScrubExtra(t *testing.T) { fmt.Sprintf("large.example.com. 10 IN SRV 0 0 80 10-0-0-%d.default.pod.k8s.example.com.", i))) } - req.scrub(reply) + req.Scrub(reply) if want, got := req.Size(), reply.Len(); want < got { t.Errorf("Want scrub to reduce message length below %d bytes, got %d bytes", want, got) } @@ -116,7 +116,7 @@ func TestRequestScrubExtraEdns0(t *testing.T) { fmt.Sprintf("large.example.com. 10 IN SRV 0 0 80 10-0-0-%d.default.pod.k8s.example.com.", i))) } - req.scrub(reply) + req.Scrub(reply) if want, got := req.Size(), reply.Len(); want < got { t.Errorf("Want scrub to reduce message length below %d bytes, got %d bytes", want, got) } @@ -146,7 +146,7 @@ func TestRequestScrubExtraRegression(t *testing.T) { fmt.Sprintf("10-0-0-%d.default.pod.k8s.example.com. 10 IN A 10.0.0.%d", i, i))) } - reply = req.scrub(reply) + reply = req.Scrub(reply) if want, got := req.Size(), reply.Len(); want < got { t.Errorf("Want scrub to reduce message length below %d bytes, got %d bytes", want, got) } @@ -171,7 +171,7 @@ func TestRequestScrubAnswerExact(t *testing.T) { reply.Answer = append(reply.Answer, test.A(fmt.Sprintf("large.example.com. 10 IN A 127.0.0.%d", i))) } - req.scrub(reply) + req.Scrub(reply) if want, got := req.Size(), reply.Len(); want < got { t.Errorf("Want scrub to reduce message length below %d bytes, got %d bytes", want, got) } diff --git a/request/writer.go b/request/writer.go index ef0c14417..ffbbe93e3 100644 --- a/request/writer.go +++ b/request/writer.go @@ -15,6 +15,6 @@ func NewScrubWriter(req *dns.Msg, w dns.ResponseWriter) *ScrubWriter { return &S // scrub on the message m and will then write it to the client. func (s *ScrubWriter) WriteMsg(m *dns.Msg) error { state := Request{Req: s.req, W: s.ResponseWriter} - new, _ := state.Scrub(m) - return s.ResponseWriter.WriteMsg(new) + n := state.Scrub(m) + return s.ResponseWriter.WriteMsg(n) } diff --git a/test/compression_scrub_test.go b/test/compression_scrub_test.go new file mode 100644 index 000000000..b18f1fe0a --- /dev/null +++ b/test/compression_scrub_test.go @@ -0,0 +1,60 @@ +package test + +import ( + "net" + "testing" + + "github.com/miekg/dns" +) + +func TestCompressScrub(t *testing.T) { + corefile := `example.org:0 { + erratic { + drop 0 + delay 0 + large + } + }` + + i, udp, _, err := CoreDNSServerAndPorts(corefile) + if err != nil { + t.Fatalf("Could not get CoreDNS serving instance: %s", err) + } + defer i.Stop() + + c, err := net.Dial("udp", udp) + if err != nil { + t.Fatalf("Could not dial %s", err) + } + m := new(dns.Msg) + m.SetQuestion("example.org.", dns.TypeA) + q, _ := m.Pack() + + c.Write(q) + buf := make([]byte, 1024) + n, err := c.Read(buf) + if err != nil || n == 0 { + t.Errorf("Expected reply, got: %s", err) + return + } + if n >= 512 { + t.Fatalf("Expected returned packet to be < 512, got %d", n) + } + buf = buf[:n] + // If there is compression in the returned packet we should look for compression pointers, if found + // the pointers should return to the domain name in the query (the first domain name that's avaiable for + // compression. This means we're looking for a combo where the pointers is detected and the offset is 12 + // the position of the first name after the header. The erratic plugin adds 30 RRs that should all be compressed. + found := 0 + for i := 0; i < len(buf)-1; i++ { + if buf[i]&0xC0 == 0xC0 { + off := (int(buf[i])^0xC0)<<8 | int(buf[i+1]) + if off == 12 { + found++ + } + } + } + if found != 30 { + t.Errorf("Failed to find all compression pointers in the packet, wanted 30, got %d", found) + } +}