diff --git a/middleware/edns.go b/middleware/edns.go new file mode 100644 index 000000000..aaab502e0 --- /dev/null +++ b/middleware/edns.go @@ -0,0 +1,34 @@ +package middleware + +import ( + "errors" + + "github.com/miekg/dns" +) + +// Edns0Version checks the EDNS version in the request. If error +// is nil everything is OK and we can invoke the middleware. If non-nil, the +// returned Msg is valid to be returned to the client (and should). For some +// reason this response should not contain a question RR in the question section. +func Edns0Version(req *dns.Msg) (*dns.Msg, error) { + opt := req.IsEdns0() + if opt == nil { + return nil, nil + } + if opt.Version() == 0 { + return nil, nil + } + m := new(dns.Msg) + m.SetReply(req) + // zero out question section, wtf. + m.Question = nil + + o := new(dns.OPT) + o.Hdr.Name = "." + o.Hdr.Rrtype = dns.TypeOPT + o.SetVersion(0) + o.SetExtendedRcode(dns.RcodeBadVers) + m.Extra = []dns.RR{o} + + return m, errors.New("EDNS0 BADVERS") +} diff --git a/middleware/edns_test.go b/middleware/edns_test.go new file mode 100644 index 000000000..7b4e6fc66 --- /dev/null +++ b/middleware/edns_test.go @@ -0,0 +1,37 @@ +package middleware + +import ( + "testing" + + "github.com/miekg/dns" +) + +func TestEdns0Version(t *testing.T) { + m := ednsMsg() + m.Extra[0].(*dns.OPT).SetVersion(2) + + _, err := Edns0Version(m) + if err == nil { + t.Errorf("expected wrong version, but got OK") + } +} + +func TestEdns0VersionNoEdns(t *testing.T) { + m := ednsMsg() + m.Extra = nil + + _, err := Edns0Version(m) + if err != nil { + t.Errorf("expected no error, but got one: %s", err) + } +} + +func ednsMsg() *dns.Msg { + m := new(dns.Msg) + m.SetQuestion("example.com.", dns.TypeA) + o := new(dns.OPT) + o.Hdr.Name = "." + o.Hdr.Rrtype = dns.TypeOPT + m.Extra = append(m.Extra, o) + return m +} diff --git a/middleware/rcode.go b/middleware/rcode.go new file mode 100644 index 000000000..989f90fdd --- /dev/null +++ b/middleware/rcode.go @@ -0,0 +1,14 @@ +package middleware + +import ( + "strconv" + + "github.com/miekg/dns" +) + +func RcodeToString(rcode int) string { + if str, ok := dns.RcodeToString[rcode]; ok { + return str + } + return "RCODE" + strconv.Itoa(rcode) +} diff --git a/middleware/recorder.go b/middleware/recorder.go index feede34ae..d1e466ec3 100644 --- a/middleware/recorder.go +++ b/middleware/recorder.go @@ -1,7 +1,6 @@ package middleware import ( - "strconv" "time" "github.com/miekg/dns" @@ -54,27 +53,16 @@ func (r *ResponseRecorder) Write(buf []byte) (int, error) { } // Size returns the size. -func (r *ResponseRecorder) Size() int { - return r.size -} +func (r *ResponseRecorder) Size() int { return r.size } // Rcode returns the rcode. -func (r *ResponseRecorder) Rcode() string { - if rcode, ok := dns.RcodeToString[r.rcode]; ok { - return rcode - } - return "RCODE" + strconv.Itoa(r.rcode) -} +func (r *ResponseRecorder) Rcode() string { return RcodeToString(r.rcode) } // Start returns the start time of the ResponseRecorder. -func (r *ResponseRecorder) Start() time.Time { - return r.start -} +func (r *ResponseRecorder) Start() time.Time { return r.start } // Msg returns the written message from the ResponseRecorder. -func (r *ResponseRecorder) Msg() *dns.Msg { - return r.msg -} +func (r *ResponseRecorder) Msg() *dns.Msg { return r.msg } // Hijack implements dns.Hijacker. It simply wraps the underlying // ResponseWriter's Hijack method if there is one, or returns an error. diff --git a/server/server.go b/server/server.go index 67cc35ba5..7ea931daa 100644 --- a/server/server.go +++ b/server/server.go @@ -12,10 +12,10 @@ import ( "net" "os" "runtime" - "strconv" "sync" "time" + "github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware/chaos" "github.com/miekg/coredns/middleware/prometheus" @@ -279,6 +279,14 @@ func (s *Server) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { } }() + if m, err := middleware.Edns0Version(r); err != nil { // Wrong EDNS version, return at once. + qtype := dns.Type(r.Question[0].Qtype).String() + rc := middleware.RcodeToString(dns.RcodeBadVers) + metrics.Report(dropped, qtype, rc, m.Len(), time.Now()) + w.WriteMsg(m) + return + } + // Execute the optional request callback if it exists if s.ReqCallback != nil && s.ReqCallback(w, r) { return @@ -332,12 +340,7 @@ func (s *Server) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { // of the specified HTTP status code. func DefaultErrorFunc(w dns.ResponseWriter, r *dns.Msg, rcode int) { qtype := dns.Type(r.Question[0].Qtype).String() - - // this code is duplicated a few times, TODO(miek) - rc := dns.RcodeToString[rcode] - if rc == "" { - rc = "RCODE" + strconv.Itoa(rcode) - } + rc := middleware.RcodeToString(rcode) answer := new(dns.Msg) answer.SetRcode(r, rcode)