EDNS: return error on wrong version. (#95)
Split up the previous changes a bit. This PR only returns the expected error when the received packet has the wrong EDNS version. EDNS0 handling in the middleware needs a nicer abstraction, like ReflectEdns() or something.
This commit is contained in:
parent
16c035731c
commit
db3d689a8a
5 changed files with 99 additions and 23 deletions
34
middleware/edns.go
Normal file
34
middleware/edns.go
Normal file
|
@ -0,0 +1,34 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// Edns0Version checks the EDNS version in the request. If error
|
||||
// is nil everything is OK and we can invoke the middleware. If non-nil, the
|
||||
// returned Msg is valid to be returned to the client (and should). For some
|
||||
// reason this response should not contain a question RR in the question section.
|
||||
func Edns0Version(req *dns.Msg) (*dns.Msg, error) {
|
||||
opt := req.IsEdns0()
|
||||
if opt == nil {
|
||||
return nil, nil
|
||||
}
|
||||
if opt.Version() == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(req)
|
||||
// zero out question section, wtf.
|
||||
m.Question = nil
|
||||
|
||||
o := new(dns.OPT)
|
||||
o.Hdr.Name = "."
|
||||
o.Hdr.Rrtype = dns.TypeOPT
|
||||
o.SetVersion(0)
|
||||
o.SetExtendedRcode(dns.RcodeBadVers)
|
||||
m.Extra = []dns.RR{o}
|
||||
|
||||
return m, errors.New("EDNS0 BADVERS")
|
||||
}
|
37
middleware/edns_test.go
Normal file
37
middleware/edns_test.go
Normal file
|
@ -0,0 +1,37 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func TestEdns0Version(t *testing.T) {
|
||||
m := ednsMsg()
|
||||
m.Extra[0].(*dns.OPT).SetVersion(2)
|
||||
|
||||
_, err := Edns0Version(m)
|
||||
if err == nil {
|
||||
t.Errorf("expected wrong version, but got OK")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEdns0VersionNoEdns(t *testing.T) {
|
||||
m := ednsMsg()
|
||||
m.Extra = nil
|
||||
|
||||
_, err := Edns0Version(m)
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, but got one: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func ednsMsg() *dns.Msg {
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion("example.com.", dns.TypeA)
|
||||
o := new(dns.OPT)
|
||||
o.Hdr.Name = "."
|
||||
o.Hdr.Rrtype = dns.TypeOPT
|
||||
m.Extra = append(m.Extra, o)
|
||||
return m
|
||||
}
|
14
middleware/rcode.go
Normal file
14
middleware/rcode.go
Normal file
|
@ -0,0 +1,14 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func RcodeToString(rcode int) string {
|
||||
if str, ok := dns.RcodeToString[rcode]; ok {
|
||||
return str
|
||||
}
|
||||
return "RCODE" + strconv.Itoa(rcode)
|
||||
}
|
|
@ -1,7 +1,6 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
@ -54,27 +53,16 @@ func (r *ResponseRecorder) Write(buf []byte) (int, error) {
|
|||
}
|
||||
|
||||
// Size returns the size.
|
||||
func (r *ResponseRecorder) Size() int {
|
||||
return r.size
|
||||
}
|
||||
func (r *ResponseRecorder) Size() int { return r.size }
|
||||
|
||||
// Rcode returns the rcode.
|
||||
func (r *ResponseRecorder) Rcode() string {
|
||||
if rcode, ok := dns.RcodeToString[r.rcode]; ok {
|
||||
return rcode
|
||||
}
|
||||
return "RCODE" + strconv.Itoa(r.rcode)
|
||||
}
|
||||
func (r *ResponseRecorder) Rcode() string { return RcodeToString(r.rcode) }
|
||||
|
||||
// Start returns the start time of the ResponseRecorder.
|
||||
func (r *ResponseRecorder) Start() time.Time {
|
||||
return r.start
|
||||
}
|
||||
func (r *ResponseRecorder) Start() time.Time { return r.start }
|
||||
|
||||
// Msg returns the written message from the ResponseRecorder.
|
||||
func (r *ResponseRecorder) Msg() *dns.Msg {
|
||||
return r.msg
|
||||
}
|
||||
func (r *ResponseRecorder) Msg() *dns.Msg { return r.msg }
|
||||
|
||||
// Hijack implements dns.Hijacker. It simply wraps the underlying
|
||||
// ResponseWriter's Hijack method if there is one, or returns an error.
|
||||
|
|
|
@ -12,10 +12,10 @@ import (
|
|||
"net"
|
||||
"os"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/coredns/middleware"
|
||||
"github.com/miekg/coredns/middleware/chaos"
|
||||
"github.com/miekg/coredns/middleware/prometheus"
|
||||
|
||||
|
@ -279,6 +279,14 @@ func (s *Server) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||
}
|
||||
}()
|
||||
|
||||
if m, err := middleware.Edns0Version(r); err != nil { // Wrong EDNS version, return at once.
|
||||
qtype := dns.Type(r.Question[0].Qtype).String()
|
||||
rc := middleware.RcodeToString(dns.RcodeBadVers)
|
||||
metrics.Report(dropped, qtype, rc, m.Len(), time.Now())
|
||||
w.WriteMsg(m)
|
||||
return
|
||||
}
|
||||
|
||||
// Execute the optional request callback if it exists
|
||||
if s.ReqCallback != nil && s.ReqCallback(w, r) {
|
||||
return
|
||||
|
@ -332,12 +340,7 @@ func (s *Server) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||
// of the specified HTTP status code.
|
||||
func DefaultErrorFunc(w dns.ResponseWriter, r *dns.Msg, rcode int) {
|
||||
qtype := dns.Type(r.Question[0].Qtype).String()
|
||||
|
||||
// this code is duplicated a few times, TODO(miek)
|
||||
rc := dns.RcodeToString[rcode]
|
||||
if rc == "" {
|
||||
rc = "RCODE" + strconv.Itoa(rcode)
|
||||
}
|
||||
rc := middleware.RcodeToString(rcode)
|
||||
|
||||
answer := new(dns.Msg)
|
||||
answer.SetRcode(r, rcode)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue