diff --git a/plugin/cache/cache.go b/plugin/cache/cache.go index 8dae9a42b..2a56500a3 100644 --- a/plugin/cache/cache.go +++ b/plugin/cache/cache.go @@ -143,15 +143,12 @@ func (w *ResponseWriter) RemoteAddr() net.Addr { // WriteMsg implements the dns.ResponseWriter interface. func (w *ResponseWriter) WriteMsg(res *dns.Msg) error { - // res needs to be copied otherwise we will be modifying the underlaying arrays which are now cached. - resc := res.Copy() - - mt, _ := response.Typify(resc, w.now().UTC()) + mt, _ := response.Typify(res, w.now().UTC()) // key returns empty string for anything we don't want to cache. - hasKey, key := key(w.state.Name(), resc, mt) + hasKey, key := key(w.state.Name(), res, mt) - msgTTL := dnsutil.MinimalTTL(resc, mt) + msgTTL := dnsutil.MinimalTTL(res, mt) var duration time.Duration if mt == response.NameError || mt == response.NoData { duration = computeTTL(msgTTL, w.minnttl, w.nttl) @@ -163,8 +160,8 @@ func (w *ResponseWriter) WriteMsg(res *dns.Msg) error { } if hasKey && duration > 0 { - if w.state.Match(resc) { - w.set(resc, key, mt, duration) + if w.state.Match(res) { + w.set(res, key, mt, duration) cacheSize.WithLabelValues(w.server, Success).Set(float64(w.pcache.Len())) cacheSize.WithLabelValues(w.server, Denial).Set(float64(w.ncache.Len())) } else { @@ -180,11 +177,11 @@ func (w *ResponseWriter) WriteMsg(res *dns.Msg) error { // Apply capped TTL to this reply to avoid jarring TTL experience 1799 -> 8 (e.g.) // We also may need to filter out DNSSEC records, see toMsg() for similar code. ttl := uint32(duration.Seconds()) - resc.Answer = filterRRSlice(resc.Answer, ttl, w.do, false) - resc.Ns = filterRRSlice(resc.Ns, ttl, w.do, false) - resc.Extra = filterRRSlice(resc.Extra, ttl, w.do, false) + res.Answer = filterRRSlice(res.Answer, ttl, w.do, false) + res.Ns = filterRRSlice(res.Ns, ttl, w.do, false) + res.Extra = filterRRSlice(res.Extra, ttl, w.do, false) - return w.ResponseWriter.WriteMsg(resc) + return w.ResponseWriter.WriteMsg(res) } func (w *ResponseWriter) set(m *dns.Msg, key uint64, mt response.Type, duration time.Duration) { diff --git a/plugin/cache/handler.go b/plugin/cache/handler.go index 987dd61b2..406ece8e6 100644 --- a/plugin/cache/handler.go +++ b/plugin/cache/handler.go @@ -14,12 +14,13 @@ import ( // ServeDNS implements the plugin.Handler interface. func (c *Cache) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { - state := request.Request{W: w, Req: r} + rc := r.Copy() // We potentially modify r, to prevent other plugins from seeing this (r is a pointer), copy r into rc. + state := request.Request{W: w, Req: rc} do := state.Do() zone := plugin.Zones(c.Zones).Matches(state.Name()) if zone == "" { - return plugin.NextOrFailure(c.Name(), c.Next, ctx, w, r) + return plugin.NextOrFailure(c.Name(), c.Next, ctx, w, rc) } now := c.now().UTC() @@ -39,22 +40,21 @@ func (c *Cache) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) } if i == nil { if !do { - setDo(r) + setDo(rc) } crr := &ResponseWriter{ResponseWriter: w, Cache: c, state: state, server: server, do: do} - return plugin.NextOrFailure(c.Name(), c.Next, ctx, crr, r) + return plugin.NextOrFailure(c.Name(), c.Next, ctx, crr, rc) } if ttl < 0 { servedStale.WithLabelValues(server).Inc() // Adjust the time to get a 0 TTL in the reply built from a stale item. now = now.Add(time.Duration(ttl) * time.Second) go func() { - r := r.Copy() if !do { - setDo(r) + setDo(rc) } crr := &ResponseWriter{Cache: c, state: state, server: server, prefetch: true, remoteAddr: w.LocalAddr(), do: do} - plugin.NextOrFailure(c.Name(), c.Next, ctx, crr, r) + plugin.NextOrFailure(c.Name(), c.Next, ctx, crr, rc) }() } resp := i.toMsg(r, now, do) diff --git a/test/cache_test.go b/test/cache_test.go index 24bd7e051..831c39fa5 100644 --- a/test/cache_test.go +++ b/test/cache_test.go @@ -110,3 +110,48 @@ func testCaseDNSSEC(t *testing.T, name, addr string, bufsize int) { } } } + +func TestLookupCacheWithoutEdns(t *testing.T) { + name, rm, err := test.TempFile(".", exampleOrg) + if err != nil { + t.Fatalf("Failed to create zone: %s", err) + } + defer rm() + + corefile := `example.org:0 { + file ` + name + ` + }` + + i, udp, _, err := CoreDNSServerAndPorts(corefile) + if err != nil { + t.Fatalf("Could not get CoreDNS serving instance: %s", err) + } + defer i.Stop() + + // Start caching forward CoreDNS that we want to test. + corefile = `example.org:0 { + forward . ` + udp + ` + cache 10 + }` + + i, udp, _, err = CoreDNSServerAndPorts(corefile) + if err != nil { + t.Fatalf("Could not get CoreDNS serving instance: %s", err) + } + defer i.Stop() + + m := new(dns.Msg) + m.SetQuestion("example.org.", dns.TypeA) + resp, err := dns.Exchange(m, udp) + if err != nil { + t.Fatalf("Expected to receive reply, but didn't: %s", err) + } + if len(resp.Extra) == 0 { + return + } + + if resp.Extra[0].Header().Rrtype == dns.TypeOPT { + t.Fatalf("Expected no OPT RR, but got: %s", resp.Extra[0]) + } + t.Fatalf("Expected empty additional section, got %v", resp.Extra) +}