diff --git a/plugin/cache/cache.go b/plugin/cache/cache.go index 32185de19..f5edc001b 100644 --- a/plugin/cache/cache.go +++ b/plugin/cache/cache.go @@ -142,12 +142,15 @@ func (w *ResponseWriter) RemoteAddr() net.Addr { // WriteMsg implements the dns.ResponseWriter interface. func (w *ResponseWriter) WriteMsg(res *dns.Msg) error { - mt, _ := response.Typify(res, w.now().UTC()) + // res needs to be copied otherwise we will be modifying the underlaying arrays which are now cached. + resc := res.Copy() + + mt, _ := response.Typify(resc, w.now().UTC()) // key returns empty string for anything we don't want to cache. - hasKey, key := key(w.state.Name(), res, mt) + hasKey, key := key(w.state.Name(), resc, mt) - msgTTL := dnsutil.MinimalTTL(res, mt) + msgTTL := dnsutil.MinimalTTL(resc, mt) var duration time.Duration if mt == response.NameError || mt == response.NoData { duration = computeTTL(msgTTL, w.minnttl, w.nttl) @@ -159,8 +162,8 @@ func (w *ResponseWriter) WriteMsg(res *dns.Msg) error { } if hasKey && duration > 0 { - if w.state.Match(res) { - w.set(res, key, mt, duration) + if w.state.Match(resc) { + w.set(resc, key, mt, duration) cacheSize.WithLabelValues(w.server, Success).Set(float64(w.pcache.Len())) cacheSize.WithLabelValues(w.server, Denial).Set(float64(w.ncache.Len())) } else { @@ -174,39 +177,14 @@ func (w *ResponseWriter) WriteMsg(res *dns.Msg) error { } do := w.state.Do() - // Apply capped TTL to this reply to avoid jarring TTL experience 1799 -> 8 (e.g.) // We also may need to filter out DNSSEC records, see toMsg() for similar code. ttl := uint32(duration.Seconds()) - j := 0 - for _, r := range res.Answer { - if !do && isDNSSEC(r) { - continue - } - res.Answer[j].Header().Ttl = ttl - j++ - } - res.Answer = res.Answer[:j] - j = 0 - for _, r := range res.Ns { - if !do && isDNSSEC(r) { - continue - } - res.Ns[j].Header().Ttl = ttl - j++ - } - res.Ns = res.Ns[:j] - j = 0 - for _, r := range res.Extra { - if !do && isDNSSEC(r) { - continue - } - if res.Extra[j].Header().Rrtype != dns.TypeOPT { - res.Extra[j].Header().Ttl = ttl - } - j++ - } - return w.ResponseWriter.WriteMsg(res) + resc.Answer = filterRRSlice(resc.Answer, ttl, do, false) + resc.Ns = filterRRSlice(resc.Ns, ttl, do, false) + resc.Extra = filterRRSlice(resc.Extra, ttl, do, false) + + return w.ResponseWriter.WriteMsg(resc) } func (w *ResponseWriter) set(m *dns.Msg, key uint64, mt response.Type, duration time.Duration) { diff --git a/plugin/cache/cache_test.go b/plugin/cache/cache_test.go index 717276e66..9417a59f6 100644 --- a/plugin/cache/cache_test.go +++ b/plugin/cache/cache_test.go @@ -216,13 +216,13 @@ func TestCache(t *testing.T) { resp := i.toMsg(m, time.Now().UTC(), state.Do()) if err := test.Header(tc.Case, resp); err != nil { - t.Logf("Bla %v", resp) + t.Logf("Cache %v", resp) t.Error(err) continue } if err := test.Section(tc.Case, test.Answer, resp.Answer); err != nil { - t.Logf("Bla %v -- %v", test.Answer, resp.Answer) + t.Logf("Cache %v -- %v", test.Answer, resp.Answer) t.Error(err) } if err := test.Section(tc.Case, test.Ns, resp.Ns); err != nil { diff --git a/plugin/cache/dnssec.go b/plugin/cache/dnssec.go new file mode 100644 index 000000000..72520e345 --- /dev/null +++ b/plugin/cache/dnssec.go @@ -0,0 +1,43 @@ +package cache + +import "github.com/miekg/dns" + +// isDNSSEC returns true if r is a DNSSEC record. NSEC,NSEC3,DS and RRSIG/SIG +// are DNSSEC records. DNSKEYs is not in this list on the assumption that the +// client explictly asked for it. +func isDNSSEC(r dns.RR) bool { + switch r.Header().Rrtype { + case dns.TypeNSEC: + return true + case dns.TypeNSEC3: + return true + case dns.TypeDS: + return true + case dns.TypeRRSIG: + return true + case dns.TypeSIG: + return true + } + return false +} + +// filterRRSlice filters rrs and removes DNSSEC RRs when do is false. In the returned slice +// the TTLs are set to ttl. If dup is true the RRs in rrs are _copied_ into the slice that is +// returned. +func filterRRSlice(rrs []dns.RR, ttl uint32, do, dup bool) []dns.RR { + j := 0 + rs := make([]dns.RR, len(rrs), len(rrs)) + for _, r := range rrs { + if !do && isDNSSEC(r) { + continue + } + r.Header().Ttl = ttl + if dup { + rs[j] = dns.Copy(r) + } else { + rs[j] = r + } + j++ + } + return rs[:j] +} diff --git a/plugin/cache/dnssec_test.go b/plugin/cache/dnssec_test.go new file mode 100644 index 000000000..446718c9f --- /dev/null +++ b/plugin/cache/dnssec_test.go @@ -0,0 +1,112 @@ +package cache + +import ( + "context" + "testing" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +func TestResponseWithDNSSEC(t *testing.T) { + // We do 2 queries, one where we want non-dnssec and one with dnssec and check the responses in each of them + var tcs = []test.Case{ + { + Qname: "invent.example.org.", Qtype: dns.TypeA, + Answer: []dns.RR{ + test.CNAME("invent.example.org. 1781 IN CNAME leptone.example.org."), + test.A("leptone.example.org. 1781 IN A 195.201.182.103"), + }, + }, + { + Qname: "invent.example.org.", Qtype: dns.TypeA, + Do: true, + Answer: []dns.RR{ + test.CNAME("invent.example.org. 1781 IN CNAME leptone.example.org."), + test.RRSIG("invent.example.org. 1781 IN RRSIG CNAME 8 3 1800 20201012085750 20200912082613 57411 example.org. ijSv5FmsNjFviBcOFwQgqjt073lttxTTNqkno6oMa3DD3kC+"), + test.A("leptone.example.org. 1781 IN A 195.201.182.103"), + test.RRSIG("leptone.example.org. 1781 IN RRSIG A 8 3 1800 20201012093630 20200912083827 57411 example.org. eLuSOkLAzm/WIOpaZD3/4TfvKP1HAFzjkis9LIJSRVpQt307dm9WY9"), + }, + }, + } + + c := New() + c.Next = dnssecHandler() + + for i, tc := range tcs { + m := tc.Msg() + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + c.ServeDNS(context.TODO(), rec, m) + if err := test.Section(tc, test.Answer, rec.Msg.Answer); err != nil { + t.Errorf("Test %d, expected no error, got %s", i, err) + } + } + + // now do the reverse + c = New() + c.Next = dnssecHandler() + + for i, tc := range []test.Case{tcs[1], tcs[0]} { + m := tc.Msg() + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + c.ServeDNS(context.TODO(), rec, m) + if err := test.Section(tc, test.Answer, rec.Msg.Answer); err != nil { + t.Errorf("Test %d, expected no error, got %s", i, err) + } + } +} + +func dnssecHandler() plugin.Handler { + return plugin.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + m := new(dns.Msg) + m.SetQuestion("example.org.", dns.TypeA) + + m.Answer = make([]dns.RR, 4) + m.Answer[0] = test.CNAME("invent.example.org. 1781 IN CNAME leptone.example.org.") + m.Answer[1] = test.RRSIG("invent.example.org. 1781 IN RRSIG CNAME 8 3 1800 20201012085750 20200912082613 57411 example.org. ijSv5FmsNjFviBcOFwQgqjt073lttxTTNqkno6oMa3DD3kC+") + m.Answer[2] = test.A("leptone.example.org. 1781 IN A 195.201.182.103") + m.Answer[3] = test.RRSIG("leptone.example.org. 1781 IN RRSIG A 8 3 1800 20201012093630 20200912083827 57411 example.org. eLuSOkLAzm/WIOpaZD3/4TfvKP1HAFzjkis9LIJSRVpQt307dm9WY9") + w.WriteMsg(m) + return dns.RcodeSuccess, nil + }) +} + +func TestFliterRRSlice(t *testing.T) { + rrs := []dns.RR{ + test.CNAME("invent.example.org. 1781 IN CNAME leptone.example.org."), + test.RRSIG("invent.example.org. 1781 IN RRSIG CNAME 8 3 1800 20201012085750 20200912082613 57411 example.org. ijSv5FmsNjFviBcOFwQgqjt073lttxTTNqkno6oMa3DD3kC+"), + test.A("leptone.example.org. 1781 IN A 195.201.182.103"), + test.RRSIG("leptone.example.org. 1781 IN RRSIG A 8 3 1800 20201012093630 20200912083827 57411 example.org. eLuSOkLAzm/WIOpaZD3/4TfvKP1HAFzjkis9LIJSRVpQt307dm9WY9"), + } + + filter1 := filterRRSlice(rrs, 0, true, false) + if len(filter1) != 4 { + t.Errorf("Expected 4 RRs after filtering, got %d", len(filter1)) + } + rrsig := 0 + for _, f := range filter1 { + if f.Header().Rrtype == dns.TypeRRSIG { + rrsig++ + } + } + if rrsig != 2 { + t.Errorf("Expected 2 RRSIGs after filtering, got %d", rrsig) + } + + filter2 := filterRRSlice(rrs, 0, false, false) + if len(filter2) != 2 { + t.Errorf("Expected 2 RRs after filtering, got %d", len(filter2)) + } + rrsig = 0 + for _, f := range filter2 { + if f.Header().Rrtype == dns.TypeRRSIG { + rrsig++ + } + } + if rrsig != 0 { + t.Errorf("Expected 0 RRSIGs after filtering, got %d", rrsig) + } +} diff --git a/plugin/cache/do_test.go b/plugin/cache/do_test.go deleted file mode 100644 index 3cf87cabe..000000000 --- a/plugin/cache/do_test.go +++ /dev/null @@ -1,75 +0,0 @@ -package cache - -import ( - "context" - "testing" - - "github.com/coredns/coredns/plugin" - "github.com/coredns/coredns/plugin/pkg/dnstest" - "github.com/coredns/coredns/plugin/test" - - "github.com/miekg/dns" -) - -func TestDo(t *testing.T) { - // cache sets Do and requests that don't have them. - c := New() - c.Next = echoHandler() - req := new(dns.Msg) - req.SetQuestion("example.org.", dns.TypeA) - rec := dnstest.NewRecorder(&test.ResponseWriter{}) - - // No DO set. - c.ServeDNS(context.TODO(), rec, req) - reply := rec.Msg - opt := reply.Extra[len(reply.Extra)-1] - if x, ok := opt.(*dns.OPT); !ok { - t.Fatalf("Expected OPT RR, got %T", x) - } - if !opt.(*dns.OPT).Do() { - t.Errorf("Expected DO bit to be set, got false") - } - if x := opt.(*dns.OPT).UDPSize(); x != defaultUDPBufSize { - t.Errorf("Expected %d bufsize, got %d", defaultUDPBufSize, x) - } - - // Do set - so left alone. - const mysize = defaultUDPBufSize * 2 - setDo(req) - // set bufsize to something else than default to see cache doesn't touch it - req.Extra[len(req.Extra)-1].(*dns.OPT).SetUDPSize(mysize) - c.ServeDNS(context.TODO(), rec, req) - reply = rec.Msg - opt = reply.Extra[len(reply.Extra)-1] - if x, ok := opt.(*dns.OPT); !ok { - t.Fatalf("Expected OPT RR, got %T", x) - } - if !opt.(*dns.OPT).Do() { - t.Errorf("Expected DO bit to be set, got false") - } - if x := opt.(*dns.OPT).UDPSize(); x != mysize { - t.Errorf("Expected %d bufsize, got %d", mysize, x) - } - - // edns0 set, but not DO, so _not_ left alone. - req.Extra[len(req.Extra)-1].(*dns.OPT).SetDo(false) - c.ServeDNS(context.TODO(), rec, req) - reply = rec.Msg - opt = reply.Extra[len(reply.Extra)-1] - if x, ok := opt.(*dns.OPT); !ok { - t.Fatalf("Expected OPT RR, got %T", x) - } - if !opt.(*dns.OPT).Do() { - t.Errorf("Expected DO bit to be set, got false") - } - if x := opt.(*dns.OPT).UDPSize(); x != defaultUDPBufSize { - t.Errorf("Expected %d bufsize, got %d", defaultUDPBufSize, x) - } -} - -func echoHandler() plugin.Handler { - return plugin.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { - w.WriteMsg(r) - return dns.RcodeSuccess, nil - }) -} diff --git a/plugin/cache/item.go b/plugin/cache/item.go index 989d57bb0..bda5fe746 100644 --- a/plugin/cache/item.go +++ b/plugin/cache/item.go @@ -75,37 +75,10 @@ func (i *item) toMsg(m *dns.Msg, now time.Time, do bool) *dns.Msg { m1.Extra = make([]dns.RR, len(i.Extra)) ttl := uint32(i.ttl(now)) - j := 0 - for _, r := range i.Answer { - if !do && isDNSSEC(r) { - continue - } - m1.Answer[j] = dns.Copy(r) - m1.Answer[j].Header().Ttl = ttl - j++ - } - m1.Answer = m1.Answer[:j] - j = 0 - for _, r := range i.Ns { - if !do && isDNSSEC(r) { - continue - } - m1.Ns[j] = dns.Copy(r) - m1.Ns[j].Header().Ttl = ttl - j++ - } - m1.Ns = m1.Ns[:j] - // newItem skips OPT records, so we can just use i.Extra as is. - j = 0 - for _, r := range i.Extra { - if !do && isDNSSEC(r) { - continue - } - m1.Extra[j] = dns.Copy(r) - m1.Extra[j].Header().Ttl = ttl - j++ - } - m1.Extra = m1.Extra[:j] + m1.Answer = filterRRSlice(i.Answer, ttl, do, true) + m1.Ns = filterRRSlice(i.Ns, ttl, do, true) + m1.Extra = filterRRSlice(i.Extra, ttl, do, true) + return m1 } @@ -113,22 +86,3 @@ func (i *item) ttl(now time.Time) int { ttl := int(i.origTTL) - int(now.UTC().Sub(i.stored).Seconds()) return ttl } - -// isDNSSEC returns true if r is a DNSSEC record. NSEC,NSEC3,DS and RRSIG/SIG -// are DNSSEC records. DNSKEYs is not in this list on the assumption that the -// client explictly asked for it. -func isDNSSEC(r dns.RR) bool { - switch r.Header().Rrtype { - case dns.TypeNSEC: - return true - case dns.TypeNSEC3: - return true - case dns.TypeDS: - return true - case dns.TypeRRSIG: - return true - case dns.TypeSIG: - return true - } - return false -} diff --git a/test/cache_test.go b/test/cache_test.go index 9cc36696b..24bd7e051 100644 --- a/test/cache_test.go +++ b/test/cache_test.go @@ -46,6 +46,13 @@ func TestLookupCache(t *testing.T) { testCase(t, "short.example.org.", udp, 1, 5) }) + t.Run("DNSSEC OPT", func(t *testing.T) { + testCaseDNSSEC(t, "example.org.", udp, 4096) + }) + + t.Run("DNSSEC OPT", func(t *testing.T) { + testCaseDNSSEC(t, "example.org.", udp, 0) + }) } func testCase(t *testing.T, name, addr string, expectAnsLen int, expectTTL uint32) { @@ -53,7 +60,7 @@ func testCase(t *testing.T, name, addr string, expectAnsLen int, expectTTL uint3 m.SetQuestion(name, dns.TypeA) resp, err := dns.Exchange(m, addr) if err != nil { - t.Fatal("Expected to receive reply, but didn't") + t.Fatalf("Expected to receive reply, but didn't: %s", err) } if len(resp.Answer) != expectAnsLen { @@ -65,3 +72,41 @@ func testCase(t *testing.T, name, addr string, expectAnsLen int, expectTTL uint3 t.Errorf("Expected TTL to be %d, got %d", expectTTL, ttl) } } + +func testCaseDNSSEC(t *testing.T, name, addr string, bufsize int) { + m := new(dns.Msg) + m.SetQuestion(name, dns.TypeA) + + if bufsize > 0 { + o := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}} + o.SetDo() + o.SetUDPSize(uint16(bufsize)) + m.Extra = append(m.Extra, o) + } + resp, err := dns.Exchange(m, addr) + if err != nil { + t.Fatalf("Expected to receive reply, but didn't: %s", err) + } + + if len(resp.Extra) == 0 && bufsize == 0 { + // no OPT, this is OK + return + } + + opt := resp.Extra[len(resp.Extra)-1] + if x, ok := opt.(*dns.OPT); !ok && bufsize > 0 { + t.Fatalf("Expected OPT RR, got %T", x) + } + if bufsize > 0 { + if !opt.(*dns.OPT).Do() { + t.Errorf("Expected DO bit to be set, got false") + } + if x := opt.(*dns.OPT).UDPSize(); int(x) != bufsize { + t.Errorf("Expected %d bufsize, got %d", bufsize, x) + } + } else { + if opt.Header().Rrtype == dns.TypeOPT { + t.Errorf("Expected no OPT RR, but got one: %s", opt) + } + } +}