diff --git a/core/dnsserver/server.go b/core/dnsserver/server.go index 8e2f0ba76..3f01cac5f 100644 --- a/core/dnsserver/server.go +++ b/core/dnsserver/server.go @@ -289,12 +289,6 @@ func DefaultErrorFunc(w dns.ResponseWriter, r *dns.Msg, rc int) { answer := new(dns.Msg) answer.SetRcode(r, rc) - if r == nil { - log.Printf("[WARNING] DefaultErrorFunc called with nil *dns.Msg (Remote: %s)", w.RemoteAddr().String()) - w.WriteMsg(answer) - return - } - state.SizeAndDo(answer) vars.Report(state, vars.Dropped, rcode.ToString(rc), answer.Len(), time.Now()) diff --git a/request/request.go b/request/request.go index 89b4a4cf4..a2428c464 100644 --- a/request/request.go +++ b/request/request.go @@ -205,31 +205,93 @@ func (r *Request) Scrub(reply *dns.Msg) (*dns.Msg, Result) { return reply, ScrubDone } -// Type returns the type of the question as a string. -func (r *Request) Type() string { return dns.Type(r.Req.Question[0].Qtype).String() } +// Type returns the type of the question as a string. If the request is malformed +// the empty string is returned. +func (r *Request) Type() string { + if r.Req == nil { + return "" + } + if len(r.Req.Question) == 0 { + return "" + } -// QType returns the type of the question as an uint16. -func (r *Request) QType() uint16 { return r.Req.Question[0].Qtype } + return dns.Type(r.Req.Question[0].Qtype).String() +} + +// QType returns the type of the question as an uint16. If the request is malformed +// 0 is returned. +func (r *Request) QType() uint16 { + if r.Req == nil { + return 0 + } + if len(r.Req.Question) == 0 { + return 0 + } + + return r.Req.Question[0].Qtype +} // Name returns the name of the question in the request. Note // this name will always have a closing dot and will be lower cased. After a call Name // the value will be cached. To clear this caching call Clear. +// If the request is malformed the root zone is returned. func (r *Request) Name() string { if r.name != "" { return r.name } + if r.Req == nil { + r.name = "." + return "." + } + if len(r.Req.Question) == 0 { + r.name = "." + return "." + } + r.name = strings.ToLower(dns.Name(r.Req.Question[0].Name).String()) return r.name } // QName returns the name of the question in the request. -func (r *Request) QName() string { return dns.Name(r.Req.Question[0].Name).String() } +// If the request is malformed the root zone is returned. +func (r *Request) QName() string { + if r.Req == nil { + return "." + } + if len(r.Req.Question) == 0 { + return "." + } + + return dns.Name(r.Req.Question[0].Name).String() +} // Class returns the class of the question in the request. -func (r *Request) Class() string { return dns.Class(r.Req.Question[0].Qclass).String() } +// If the request is malformed the empty string is returned. +func (r *Request) Class() string { + if r.Req == nil { + return "" + } + if len(r.Req.Question) == 0 { + return "" + } + + return dns.Class(r.Req.Question[0].Qclass).String() + +} // QClass returns the class of the question in the request. -func (r *Request) QClass() uint16 { return r.Req.Question[0].Qclass } +// If the request is malformed 0 returned. +func (r *Request) QClass() uint16 { + if r.Req == nil { + return 0 + } + if len(r.Req.Question) == 0 { + return 0 + } + + return r.Req.Question[0].Qclass + +} // ErrorMessage returns an error message suitable for sending // back to the client. diff --git a/request/request_test.go b/request/request_test.go index 563e6dfe2..bc08ea6ab 100644 --- a/request/request_test.go +++ b/request/request_test.go @@ -31,6 +31,35 @@ func TestRequestRemote(t *testing.T) { } } +func TestRequestMalformed(t *testing.T) { + m := new(dns.Msg) + st := Request{Req: m} + + if x := st.QType(); x != 0 { + t.Errorf("Expected 0 Qtype, got %d", x) + } + + if x := st.QClass(); x != 0 { + t.Errorf("Expected 0 QClass, got %d", x) + } + + if x := st.QName(); x != "." { + t.Errorf("Expected . Qname, got %s", x) + } + + if x := st.Name(); x != "." { + t.Errorf("Expected . Name, got %s", x) + } + + if x := st.Type(); x != "" { + t.Errorf("Expected empty Type, got %s", x) + } + + if x := st.Class(); x != "" { + t.Errorf("Expected empty Class, got %s", x) + } +} + func BenchmarkRequestDo(b *testing.B) { st := testRequest()