cache: do the msg copy right (#4207)
Not sure why this is proving so difficult.. pointers are hard? [Was tempted to rollback all tweaks here, but the original issue we're fixing it too important to not have a proper fix]. But we need to make a copy of the message at the earliest point in the handler because we are changing it (adding an opt rr). If we do this on the original message (which is a pointer) we change it (obvs). When undoing those changes we do work on a copy. Re: testing. There isn't a explicit test for this, so I've added on to the top-level test/ directory, which indeed makes the issue visible: master: ~~~ go test -v -run=TestLookupCacheWithoutEdns === RUN TestLookupCacheWithoutEdns cache_test.go:154: Expected no OPT RR, but got: ;; OPT PSEUDOSECTION: ; EDNS: version 0; flags: do; udp: 2048 --- FAIL: TestLookupCacheWithoutEdns (0.01s) FAIL ~~~ This branch: ~~~ % go test -v -run=TestLookupCacheWithoutEdns === RUN TestLookupCacheWithoutEdns --- PASS: TestLookupCacheWithoutEdns (0.01s) PASS ok github.com/coredns/coredns/test 0.109s ~~~ Signed-off-by: Miek Gieben <miek@miek.nl>
This commit is contained in:
parent
6938dac21d
commit
268781d355
3 changed files with 61 additions and 19 deletions
21
plugin/cache/cache.go
vendored
21
plugin/cache/cache.go
vendored
|
@ -143,15 +143,12 @@ func (w *ResponseWriter) RemoteAddr() net.Addr {
|
||||||
|
|
||||||
// WriteMsg implements the dns.ResponseWriter interface.
|
// WriteMsg implements the dns.ResponseWriter interface.
|
||||||
func (w *ResponseWriter) WriteMsg(res *dns.Msg) error {
|
func (w *ResponseWriter) WriteMsg(res *dns.Msg) error {
|
||||||
// res needs to be copied otherwise we will be modifying the underlaying arrays which are now cached.
|
mt, _ := response.Typify(res, w.now().UTC())
|
||||||
resc := res.Copy()
|
|
||||||
|
|
||||||
mt, _ := response.Typify(resc, w.now().UTC())
|
|
||||||
|
|
||||||
// key returns empty string for anything we don't want to cache.
|
// key returns empty string for anything we don't want to cache.
|
||||||
hasKey, key := key(w.state.Name(), resc, mt)
|
hasKey, key := key(w.state.Name(), res, mt)
|
||||||
|
|
||||||
msgTTL := dnsutil.MinimalTTL(resc, mt)
|
msgTTL := dnsutil.MinimalTTL(res, mt)
|
||||||
var duration time.Duration
|
var duration time.Duration
|
||||||
if mt == response.NameError || mt == response.NoData {
|
if mt == response.NameError || mt == response.NoData {
|
||||||
duration = computeTTL(msgTTL, w.minnttl, w.nttl)
|
duration = computeTTL(msgTTL, w.minnttl, w.nttl)
|
||||||
|
@ -163,8 +160,8 @@ func (w *ResponseWriter) WriteMsg(res *dns.Msg) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
if hasKey && duration > 0 {
|
if hasKey && duration > 0 {
|
||||||
if w.state.Match(resc) {
|
if w.state.Match(res) {
|
||||||
w.set(resc, key, mt, duration)
|
w.set(res, key, mt, duration)
|
||||||
cacheSize.WithLabelValues(w.server, Success).Set(float64(w.pcache.Len()))
|
cacheSize.WithLabelValues(w.server, Success).Set(float64(w.pcache.Len()))
|
||||||
cacheSize.WithLabelValues(w.server, Denial).Set(float64(w.ncache.Len()))
|
cacheSize.WithLabelValues(w.server, Denial).Set(float64(w.ncache.Len()))
|
||||||
} else {
|
} else {
|
||||||
|
@ -180,11 +177,11 @@ func (w *ResponseWriter) WriteMsg(res *dns.Msg) error {
|
||||||
// Apply capped TTL to this reply to avoid jarring TTL experience 1799 -> 8 (e.g.)
|
// Apply capped TTL to this reply to avoid jarring TTL experience 1799 -> 8 (e.g.)
|
||||||
// We also may need to filter out DNSSEC records, see toMsg() for similar code.
|
// We also may need to filter out DNSSEC records, see toMsg() for similar code.
|
||||||
ttl := uint32(duration.Seconds())
|
ttl := uint32(duration.Seconds())
|
||||||
resc.Answer = filterRRSlice(resc.Answer, ttl, w.do, false)
|
res.Answer = filterRRSlice(res.Answer, ttl, w.do, false)
|
||||||
resc.Ns = filterRRSlice(resc.Ns, ttl, w.do, false)
|
res.Ns = filterRRSlice(res.Ns, ttl, w.do, false)
|
||||||
resc.Extra = filterRRSlice(resc.Extra, ttl, w.do, false)
|
res.Extra = filterRRSlice(res.Extra, ttl, w.do, false)
|
||||||
|
|
||||||
return w.ResponseWriter.WriteMsg(resc)
|
return w.ResponseWriter.WriteMsg(res)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *ResponseWriter) set(m *dns.Msg, key uint64, mt response.Type, duration time.Duration) {
|
func (w *ResponseWriter) set(m *dns.Msg, key uint64, mt response.Type, duration time.Duration) {
|
||||||
|
|
14
plugin/cache/handler.go
vendored
14
plugin/cache/handler.go
vendored
|
@ -14,12 +14,13 @@ import (
|
||||||
|
|
||||||
// ServeDNS implements the plugin.Handler interface.
|
// ServeDNS implements the plugin.Handler interface.
|
||||||
func (c *Cache) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
func (c *Cache) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||||
state := request.Request{W: w, Req: r}
|
rc := r.Copy() // We potentially modify r, to prevent other plugins from seeing this (r is a pointer), copy r into rc.
|
||||||
|
state := request.Request{W: w, Req: rc}
|
||||||
do := state.Do()
|
do := state.Do()
|
||||||
|
|
||||||
zone := plugin.Zones(c.Zones).Matches(state.Name())
|
zone := plugin.Zones(c.Zones).Matches(state.Name())
|
||||||
if zone == "" {
|
if zone == "" {
|
||||||
return plugin.NextOrFailure(c.Name(), c.Next, ctx, w, r)
|
return plugin.NextOrFailure(c.Name(), c.Next, ctx, w, rc)
|
||||||
}
|
}
|
||||||
|
|
||||||
now := c.now().UTC()
|
now := c.now().UTC()
|
||||||
|
@ -39,22 +40,21 @@ func (c *Cache) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg)
|
||||||
}
|
}
|
||||||
if i == nil {
|
if i == nil {
|
||||||
if !do {
|
if !do {
|
||||||
setDo(r)
|
setDo(rc)
|
||||||
}
|
}
|
||||||
crr := &ResponseWriter{ResponseWriter: w, Cache: c, state: state, server: server, do: do}
|
crr := &ResponseWriter{ResponseWriter: w, Cache: c, state: state, server: server, do: do}
|
||||||
return plugin.NextOrFailure(c.Name(), c.Next, ctx, crr, r)
|
return plugin.NextOrFailure(c.Name(), c.Next, ctx, crr, rc)
|
||||||
}
|
}
|
||||||
if ttl < 0 {
|
if ttl < 0 {
|
||||||
servedStale.WithLabelValues(server).Inc()
|
servedStale.WithLabelValues(server).Inc()
|
||||||
// Adjust the time to get a 0 TTL in the reply built from a stale item.
|
// Adjust the time to get a 0 TTL in the reply built from a stale item.
|
||||||
now = now.Add(time.Duration(ttl) * time.Second)
|
now = now.Add(time.Duration(ttl) * time.Second)
|
||||||
go func() {
|
go func() {
|
||||||
r := r.Copy()
|
|
||||||
if !do {
|
if !do {
|
||||||
setDo(r)
|
setDo(rc)
|
||||||
}
|
}
|
||||||
crr := &ResponseWriter{Cache: c, state: state, server: server, prefetch: true, remoteAddr: w.LocalAddr(), do: do}
|
crr := &ResponseWriter{Cache: c, state: state, server: server, prefetch: true, remoteAddr: w.LocalAddr(), do: do}
|
||||||
plugin.NextOrFailure(c.Name(), c.Next, ctx, crr, r)
|
plugin.NextOrFailure(c.Name(), c.Next, ctx, crr, rc)
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
resp := i.toMsg(r, now, do)
|
resp := i.toMsg(r, now, do)
|
||||||
|
|
|
@ -110,3 +110,48 @@ func testCaseDNSSEC(t *testing.T, name, addr string, bufsize int) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestLookupCacheWithoutEdns(t *testing.T) {
|
||||||
|
name, rm, err := test.TempFile(".", exampleOrg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create zone: %s", err)
|
||||||
|
}
|
||||||
|
defer rm()
|
||||||
|
|
||||||
|
corefile := `example.org:0 {
|
||||||
|
file ` + name + `
|
||||||
|
}`
|
||||||
|
|
||||||
|
i, udp, _, err := CoreDNSServerAndPorts(corefile)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Could not get CoreDNS serving instance: %s", err)
|
||||||
|
}
|
||||||
|
defer i.Stop()
|
||||||
|
|
||||||
|
// Start caching forward CoreDNS that we want to test.
|
||||||
|
corefile = `example.org:0 {
|
||||||
|
forward . ` + udp + `
|
||||||
|
cache 10
|
||||||
|
}`
|
||||||
|
|
||||||
|
i, udp, _, err = CoreDNSServerAndPorts(corefile)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Could not get CoreDNS serving instance: %s", err)
|
||||||
|
}
|
||||||
|
defer i.Stop()
|
||||||
|
|
||||||
|
m := new(dns.Msg)
|
||||||
|
m.SetQuestion("example.org.", dns.TypeA)
|
||||||
|
resp, err := dns.Exchange(m, udp)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Expected to receive reply, but didn't: %s", err)
|
||||||
|
}
|
||||||
|
if len(resp.Extra) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.Extra[0].Header().Rrtype == dns.TypeOPT {
|
||||||
|
t.Fatalf("Expected no OPT RR, but got: %s", resp.Extra[0])
|
||||||
|
}
|
||||||
|
t.Fatalf("Expected empty additional section, got %v", resp.Extra)
|
||||||
|
}
|
||||||
|
|
Loading…
Add table
Reference in a new issue