diff --git a/plugin/nsid/nsid.go b/plugin/nsid/nsid.go index b79df75bb..e2506b45f 100644 --- a/plugin/nsid/nsid.go +++ b/plugin/nsid/nsid.go @@ -19,7 +19,8 @@ type Nsid struct { // ResponseWriter is a response writer that adds NSID response type ResponseWriter struct { dns.ResponseWriter - Data string + Data string + request *dns.Msg } // ServeDNS implements the plugin.Handler interface. @@ -27,7 +28,7 @@ func (n Nsid) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (i if option := r.IsEdns0(); option != nil { for _, o := range option.Option { if _, ok := o.(*dns.EDNS0_NSID); ok { - nw := &ResponseWriter{ResponseWriter: w, Data: n.Data} + nw := &ResponseWriter{ResponseWriter: w, Data: n.Data, request: r} return plugin.NextOrFailure(n.Name(), n.Next, ctx, nw, r) } } @@ -37,16 +38,31 @@ func (n Nsid) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (i // WriteMsg implements the dns.ResponseWriter interface. func (w *ResponseWriter) WriteMsg(res *dns.Msg) error { + if w.request.IsEdns0() != nil && res.IsEdns0() == nil { + res.SetEdns0(w.request.IsEdns0().UDPSize(), true) + } + if option := res.IsEdns0(); option != nil { + var exists bool + for _, o := range option.Option { if e, ok := o.(*dns.EDNS0_NSID); ok { e.Code = dns.EDNS0NSID e.Nsid = hex.EncodeToString([]byte(w.Data)) + exists = true } } + + // Append the NSID if it doesn't exist in EDNS0 options + if !exists { + option.Option = append(option.Option, &dns.EDNS0_NSID{ + Code: dns.EDNS0NSID, + Nsid: hex.EncodeToString([]byte(w.Data)), + }) + } } - returned := w.ResponseWriter.WriteMsg(res) - return returned + + return w.ResponseWriter.WriteMsg(res) } // Name implements the Handler interface. diff --git a/plugin/nsid/nsid_test.go b/plugin/nsid/nsid_test.go index 32e8d8d59..04c21fa0a 100644 --- a/plugin/nsid/nsid_test.go +++ b/plugin/nsid/nsid_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/cache" "github.com/coredns/coredns/plugin/pkg/dnstest" "github.com/coredns/coredns/plugin/test" "github.com/coredns/coredns/plugin/whoami" @@ -71,3 +72,65 @@ func TestNsid(t *testing.T) { } } } + +func TestNsidCache(t *testing.T) { + em := Nsid{ + Data: "NSID", + } + c := cache.New() + + tests := []struct { + next plugin.Handler + qname string + qtype uint16 + expectedCode int + expectedReply string + expectedErr error + }{ + { + next: whoami.Whoami{}, + qname: ".", + expectedCode: dns.RcodeSuccess, + expectedReply: hex.EncodeToString([]byte("NSID")), + expectedErr: nil, + }, + } + + ctx := context.TODO() + + for i, tc := range tests { + req := new(dns.Msg) + if tc.qtype == 0 { + tc.qtype = dns.TypeA + } + req.SetQuestion(dns.Fqdn(tc.qname), tc.qtype) + req.Question[0].Qclass = dns.ClassINET + + req.SetEdns0(4096, false) + option := req.Extra[0].(*dns.OPT) + option.Option = append(option.Option, &dns.EDNS0_NSID{Code: dns.EDNS0NSID, Nsid: ""}) + em.Next = tc.next + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + + c.Next = em + code, err := c.ServeDNS(ctx, rec, req) + + if err != tc.expectedErr { + t.Errorf("Test %d: Expected error %v, but got %v", i, tc.expectedErr, err) + } + if code != int(tc.expectedCode) { + t.Errorf("Test %d: Expected status code %d, but got %d", i, tc.expectedCode, code) + } + if tc.expectedReply != "" { + for _, extra := range rec.Msg.Extra { + if option, ok := extra.(*dns.OPT); ok { + e := option.Option[0].(*dns.EDNS0_NSID) + if e.Nsid != tc.expectedReply { + t.Errorf("Test %d: Expected answer %s, but got %s", i, tc.expectedReply, e.Nsid) + } + } + } + } + } +}