diff --git a/plugin/forward/forward.go b/plugin/forward/forward.go index ba251d9ea..8a65dad3e 100644 --- a/plugin/forward/forward.go +++ b/plugin/forward/forward.go @@ -114,7 +114,7 @@ func (f *Forward) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg child.Finish() } - ret, err = truncated(ret, err) + ret, err = truncated(state, ret, err) upstreamErr = err if err != nil { diff --git a/plugin/forward/lookup.go b/plugin/forward/lookup.go index 77df9b8a7..cbfed545e 100644 --- a/plugin/forward/lookup.go +++ b/plugin/forward/lookup.go @@ -33,7 +33,7 @@ func (f *Forward) Forward(state request.Request) (*dns.Msg, error) { ret, err := proxy.connect(context.Background(), state, f.forceTCP, true) - ret, err = truncated(ret, err) + ret, err = truncated(state, ret, err) upstreamErr = err if err != nil { diff --git a/plugin/forward/truncated.go b/plugin/forward/truncated.go index edd68fc0c..f9bd464d1 100644 --- a/plugin/forward/truncated.go +++ b/plugin/forward/truncated.go @@ -1,10 +1,14 @@ package forward -import "github.com/miekg/dns" +import ( + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) // truncated looks at the error and if truncated return a nil errror // and a possible reconstructed dns message if that was nil. -func truncated(ret *dns.Msg, err error) (*dns.Msg, error) { +func truncated(state request.Request, ret *dns.Msg, err error) (*dns.Msg, error) { // If you query for instance ANY isc.org; you get a truncated query back which miekg/dns fails to unpack // because the RRs are not finished. The returned message can be useful or useless. Return the original // query with some header bits set that they should retry with TCP. @@ -16,7 +20,7 @@ func truncated(ret *dns.Msg, err error) (*dns.Msg, error) { m := ret if ret == nil { m = new(dns.Msg) - m.SetReply(ret) + m.SetReply(state.Req) m.Truncated = true m.Authoritative = true m.Rcode = dns.RcodeSuccess diff --git a/plugin/forward/truncated_test.go b/plugin/forward/truncated_test.go new file mode 100644 index 000000000..4d8a0a25e --- /dev/null +++ b/plugin/forward/truncated_test.go @@ -0,0 +1,114 @@ +package forward + +import ( + "sync/atomic" + "testing" + + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +func TestLookupTruncated(t *testing.T) { + i := int32(0) + s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { + j := atomic.LoadInt32(&i) + atomic.AddInt32(&i, 1) + + if j == 0 { + ret := new(dns.Msg) + ret.SetReply(r) + ret.Truncated = true + ret.Answer = append(ret.Answer, test.A("example.org. IN A 127.0.0.1")) + w.WriteMsg(ret) + return + + } + + ret := new(dns.Msg) + ret.SetReply(r) + ret.Answer = append(ret.Answer, test.A("example.org. IN A 127.0.0.1")) + w.WriteMsg(ret) + }) + defer s.Close() + + p := NewProxy(s.Addr, nil /* no TLS */) + f := New() + f.SetProxy(p) + defer f.Close() + + state := request.Request{W: &test.ResponseWriter{}, Req: new(dns.Msg)} + + resp, err := f.Lookup(state, "example.org.", dns.TypeA) + if err != nil { + t.Fatal("Expected to receive reply, but didn't") + } + // expect answer with TC + if !resp.Truncated { + t.Error("Expected to receive reply with TC bit set, but didn't") + } + + resp, err = f.Lookup(state, "example.org.", dns.TypeA) + if err != nil { + t.Fatal("Expected to receive reply, but didn't") + } + // expect answer without TC + if resp.Truncated { + t.Error("Expected to receive reply without TC bit set, but didn't") + } +} + +func TestForwardTruncated(t *testing.T) { + i := int32(0) + s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { + j := atomic.LoadInt32(&i) + atomic.AddInt32(&i, 1) + + if j == 0 { + ret := new(dns.Msg) + ret.SetReply(r) + ret.Truncated = true + ret.Answer = append(ret.Answer, test.A("example.org. IN A 127.0.0.1")) + w.WriteMsg(ret) + return + + } + + ret := new(dns.Msg) + ret.SetReply(r) + ret.Answer = append(ret.Answer, test.A("example.org. IN A 127.0.0.1")) + w.WriteMsg(ret) + }) + defer s.Close() + + f := New() + + p1 := NewProxy(s.Addr, nil /* no TLS */) + f.SetProxy(p1) + p2 := NewProxy(s.Addr, nil /* no TLS */) + f.SetProxy(p2) + defer f.Close() + + state := request.Request{W: &test.ResponseWriter{}, Req: new(dns.Msg)} + state.Req.SetQuestion("example.org.", dns.TypeA) + resp, err := f.Forward(state) + if err != nil { + t.Fatal("Expected to receive reply, but didn't") + } + + // expect answer with TC + if !resp.Truncated { + t.Error("Expected to receive reply with TC bit set, but didn't") + } + + resp, err = f.Forward(state) + if err != nil { + t.Fatal("Expected to receive reply, but didn't") + } + // expect answer without TC + if resp.Truncated { + t.Error("Expected to receive reply without TC bit set, but didn't") + } +}