EDNS: return error on wrong version. (#95)
Split up the previous changes a bit. This PR only returns the expected error when the received packet has the wrong EDNS version. EDNS0 handling in the middleware needs a nicer abstraction, like ReflectEdns() or something.
This commit is contained in:
parent
16c035731c
commit
db3d689a8a
5 changed files with 99 additions and 23 deletions
34
middleware/edns.go
Normal file
34
middleware/edns.go
Normal file
|
@ -0,0 +1,34 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Edns0Version checks the EDNS version in the request. If error
|
||||||
|
// is nil everything is OK and we can invoke the middleware. If non-nil, the
|
||||||
|
// returned Msg is valid to be returned to the client (and should). For some
|
||||||
|
// reason this response should not contain a question RR in the question section.
|
||||||
|
func Edns0Version(req *dns.Msg) (*dns.Msg, error) {
|
||||||
|
opt := req.IsEdns0()
|
||||||
|
if opt == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
if opt.Version() == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
m := new(dns.Msg)
|
||||||
|
m.SetReply(req)
|
||||||
|
// zero out question section, wtf.
|
||||||
|
m.Question = nil
|
||||||
|
|
||||||
|
o := new(dns.OPT)
|
||||||
|
o.Hdr.Name = "."
|
||||||
|
o.Hdr.Rrtype = dns.TypeOPT
|
||||||
|
o.SetVersion(0)
|
||||||
|
o.SetExtendedRcode(dns.RcodeBadVers)
|
||||||
|
m.Extra = []dns.RR{o}
|
||||||
|
|
||||||
|
return m, errors.New("EDNS0 BADVERS")
|
||||||
|
}
|
37
middleware/edns_test.go
Normal file
37
middleware/edns_test.go
Normal file
|
@ -0,0 +1,37 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestEdns0Version(t *testing.T) {
|
||||||
|
m := ednsMsg()
|
||||||
|
m.Extra[0].(*dns.OPT).SetVersion(2)
|
||||||
|
|
||||||
|
_, err := Edns0Version(m)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("expected wrong version, but got OK")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEdns0VersionNoEdns(t *testing.T) {
|
||||||
|
m := ednsMsg()
|
||||||
|
m.Extra = nil
|
||||||
|
|
||||||
|
_, err := Edns0Version(m)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("expected no error, but got one: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ednsMsg() *dns.Msg {
|
||||||
|
m := new(dns.Msg)
|
||||||
|
m.SetQuestion("example.com.", dns.TypeA)
|
||||||
|
o := new(dns.OPT)
|
||||||
|
o.Hdr.Name = "."
|
||||||
|
o.Hdr.Rrtype = dns.TypeOPT
|
||||||
|
m.Extra = append(m.Extra, o)
|
||||||
|
return m
|
||||||
|
}
|
14
middleware/rcode.go
Normal file
14
middleware/rcode.go
Normal file
|
@ -0,0 +1,14 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
)
|
||||||
|
|
||||||
|
func RcodeToString(rcode int) string {
|
||||||
|
if str, ok := dns.RcodeToString[rcode]; ok {
|
||||||
|
return str
|
||||||
|
}
|
||||||
|
return "RCODE" + strconv.Itoa(rcode)
|
||||||
|
}
|
|
@ -1,7 +1,6 @@
|
||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"strconv"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
|
@ -54,27 +53,16 @@ func (r *ResponseRecorder) Write(buf []byte) (int, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Size returns the size.
|
// Size returns the size.
|
||||||
func (r *ResponseRecorder) Size() int {
|
func (r *ResponseRecorder) Size() int { return r.size }
|
||||||
return r.size
|
|
||||||
}
|
|
||||||
|
|
||||||
// Rcode returns the rcode.
|
// Rcode returns the rcode.
|
||||||
func (r *ResponseRecorder) Rcode() string {
|
func (r *ResponseRecorder) Rcode() string { return RcodeToString(r.rcode) }
|
||||||
if rcode, ok := dns.RcodeToString[r.rcode]; ok {
|
|
||||||
return rcode
|
|
||||||
}
|
|
||||||
return "RCODE" + strconv.Itoa(r.rcode)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start returns the start time of the ResponseRecorder.
|
// Start returns the start time of the ResponseRecorder.
|
||||||
func (r *ResponseRecorder) Start() time.Time {
|
func (r *ResponseRecorder) Start() time.Time { return r.start }
|
||||||
return r.start
|
|
||||||
}
|
|
||||||
|
|
||||||
// Msg returns the written message from the ResponseRecorder.
|
// Msg returns the written message from the ResponseRecorder.
|
||||||
func (r *ResponseRecorder) Msg() *dns.Msg {
|
func (r *ResponseRecorder) Msg() *dns.Msg { return r.msg }
|
||||||
return r.msg
|
|
||||||
}
|
|
||||||
|
|
||||||
// Hijack implements dns.Hijacker. It simply wraps the underlying
|
// Hijack implements dns.Hijacker. It simply wraps the underlying
|
||||||
// ResponseWriter's Hijack method if there is one, or returns an error.
|
// ResponseWriter's Hijack method if there is one, or returns an error.
|
||||||
|
|
|
@ -12,10 +12,10 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strconv"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/miekg/coredns/middleware"
|
||||||
"github.com/miekg/coredns/middleware/chaos"
|
"github.com/miekg/coredns/middleware/chaos"
|
||||||
"github.com/miekg/coredns/middleware/prometheus"
|
"github.com/miekg/coredns/middleware/prometheus"
|
||||||
|
|
||||||
|
@ -279,6 +279,14 @@ func (s *Server) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
if m, err := middleware.Edns0Version(r); err != nil { // Wrong EDNS version, return at once.
|
||||||
|
qtype := dns.Type(r.Question[0].Qtype).String()
|
||||||
|
rc := middleware.RcodeToString(dns.RcodeBadVers)
|
||||||
|
metrics.Report(dropped, qtype, rc, m.Len(), time.Now())
|
||||||
|
w.WriteMsg(m)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Execute the optional request callback if it exists
|
// Execute the optional request callback if it exists
|
||||||
if s.ReqCallback != nil && s.ReqCallback(w, r) {
|
if s.ReqCallback != nil && s.ReqCallback(w, r) {
|
||||||
return
|
return
|
||||||
|
@ -332,12 +340,7 @@ func (s *Server) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
// of the specified HTTP status code.
|
// of the specified HTTP status code.
|
||||||
func DefaultErrorFunc(w dns.ResponseWriter, r *dns.Msg, rcode int) {
|
func DefaultErrorFunc(w dns.ResponseWriter, r *dns.Msg, rcode int) {
|
||||||
qtype := dns.Type(r.Question[0].Qtype).String()
|
qtype := dns.Type(r.Question[0].Qtype).String()
|
||||||
|
rc := middleware.RcodeToString(rcode)
|
||||||
// this code is duplicated a few times, TODO(miek)
|
|
||||||
rc := dns.RcodeToString[rcode]
|
|
||||||
if rc == "" {
|
|
||||||
rc = "RCODE" + strconv.Itoa(rcode)
|
|
||||||
}
|
|
||||||
|
|
||||||
answer := new(dns.Msg)
|
answer := new(dns.Msg)
|
||||||
answer.SetRcode(r, rcode)
|
answer.SetRcode(r, rcode)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue