EDNS: return error on wrong version. (#95)

Split up the previous changes a bit. This PR only returns the expected
error when the received packet has the wrong EDNS version.

EDNS0 handling in the middleware needs a nicer abstraction, like
ReflectEdns() or something.
This commit is contained in:
Miek Gieben 2016-04-09 11:13:04 +01:00
parent 16c035731c
commit db3d689a8a
5 changed files with 99 additions and 23 deletions

34
middleware/edns.go Normal file
View file

@ -0,0 +1,34 @@
package middleware
import (
"errors"
"github.com/miekg/dns"
)
// Edns0Version checks the EDNS version in the request. If error
// is nil everything is OK and we can invoke the middleware. If non-nil, the
// returned Msg is valid to be returned to the client (and should). For some
// reason this response should not contain a question RR in the question section.
func Edns0Version(req *dns.Msg) (*dns.Msg, error) {
opt := req.IsEdns0()
if opt == nil {
return nil, nil
}
if opt.Version() == 0 {
return nil, nil
}
m := new(dns.Msg)
m.SetReply(req)
// zero out question section, wtf.
m.Question = nil
o := new(dns.OPT)
o.Hdr.Name = "."
o.Hdr.Rrtype = dns.TypeOPT
o.SetVersion(0)
o.SetExtendedRcode(dns.RcodeBadVers)
m.Extra = []dns.RR{o}
return m, errors.New("EDNS0 BADVERS")
}

37
middleware/edns_test.go Normal file
View file

@ -0,0 +1,37 @@
package middleware
import (
"testing"
"github.com/miekg/dns"
)
func TestEdns0Version(t *testing.T) {
m := ednsMsg()
m.Extra[0].(*dns.OPT).SetVersion(2)
_, err := Edns0Version(m)
if err == nil {
t.Errorf("expected wrong version, but got OK")
}
}
func TestEdns0VersionNoEdns(t *testing.T) {
m := ednsMsg()
m.Extra = nil
_, err := Edns0Version(m)
if err != nil {
t.Errorf("expected no error, but got one: %s", err)
}
}
func ednsMsg() *dns.Msg {
m := new(dns.Msg)
m.SetQuestion("example.com.", dns.TypeA)
o := new(dns.OPT)
o.Hdr.Name = "."
o.Hdr.Rrtype = dns.TypeOPT
m.Extra = append(m.Extra, o)
return m
}

14
middleware/rcode.go Normal file
View file

@ -0,0 +1,14 @@
package middleware
import (
"strconv"
"github.com/miekg/dns"
)
func RcodeToString(rcode int) string {
if str, ok := dns.RcodeToString[rcode]; ok {
return str
}
return "RCODE" + strconv.Itoa(rcode)
}

View file

@ -1,7 +1,6 @@
package middleware package middleware
import ( import (
"strconv"
"time" "time"
"github.com/miekg/dns" "github.com/miekg/dns"
@ -54,27 +53,16 @@ func (r *ResponseRecorder) Write(buf []byte) (int, error) {
} }
// Size returns the size. // Size returns the size.
func (r *ResponseRecorder) Size() int { func (r *ResponseRecorder) Size() int { return r.size }
return r.size
}
// Rcode returns the rcode. // Rcode returns the rcode.
func (r *ResponseRecorder) Rcode() string { func (r *ResponseRecorder) Rcode() string { return RcodeToString(r.rcode) }
if rcode, ok := dns.RcodeToString[r.rcode]; ok {
return rcode
}
return "RCODE" + strconv.Itoa(r.rcode)
}
// Start returns the start time of the ResponseRecorder. // Start returns the start time of the ResponseRecorder.
func (r *ResponseRecorder) Start() time.Time { func (r *ResponseRecorder) Start() time.Time { return r.start }
return r.start
}
// Msg returns the written message from the ResponseRecorder. // Msg returns the written message from the ResponseRecorder.
func (r *ResponseRecorder) Msg() *dns.Msg { func (r *ResponseRecorder) Msg() *dns.Msg { return r.msg }
return r.msg
}
// Hijack implements dns.Hijacker. It simply wraps the underlying // Hijack implements dns.Hijacker. It simply wraps the underlying
// ResponseWriter's Hijack method if there is one, or returns an error. // ResponseWriter's Hijack method if there is one, or returns an error.

View file

@ -12,10 +12,10 @@ import (
"net" "net"
"os" "os"
"runtime" "runtime"
"strconv"
"sync" "sync"
"time" "time"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/chaos" "github.com/miekg/coredns/middleware/chaos"
"github.com/miekg/coredns/middleware/prometheus" "github.com/miekg/coredns/middleware/prometheus"
@ -279,6 +279,14 @@ func (s *Server) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
} }
}() }()
if m, err := middleware.Edns0Version(r); err != nil { // Wrong EDNS version, return at once.
qtype := dns.Type(r.Question[0].Qtype).String()
rc := middleware.RcodeToString(dns.RcodeBadVers)
metrics.Report(dropped, qtype, rc, m.Len(), time.Now())
w.WriteMsg(m)
return
}
// Execute the optional request callback if it exists // Execute the optional request callback if it exists
if s.ReqCallback != nil && s.ReqCallback(w, r) { if s.ReqCallback != nil && s.ReqCallback(w, r) {
return return
@ -332,12 +340,7 @@ func (s *Server) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
// of the specified HTTP status code. // of the specified HTTP status code.
func DefaultErrorFunc(w dns.ResponseWriter, r *dns.Msg, rcode int) { func DefaultErrorFunc(w dns.ResponseWriter, r *dns.Msg, rcode int) {
qtype := dns.Type(r.Question[0].Qtype).String() qtype := dns.Type(r.Question[0].Qtype).String()
rc := middleware.RcodeToString(rcode)
// this code is duplicated a few times, TODO(miek)
rc := dns.RcodeToString[rcode]
if rc == "" {
rc = "RCODE" + strconv.Itoa(rcode)
}
answer := new(dns.Msg) answer := new(dns.Msg)
answer.SetRcode(r, rcode) answer.SetRcode(r, rcode)