From c4ab98c6e336a1c39b3934bbb3bf691f849a6dbe Mon Sep 17 00:00:00 2001 From: Miek Gieben Date: Tue, 20 Dec 2016 18:58:05 +0000 Subject: [PATCH] Add middleware.NextOrFailure (#462) This checks if the next middleware to be called is nil, and if so returns ServerFailure and an error. This makes the next calling more robust and saves some lines of code. Also prefix the error with the name of the middleware to aid in debugging. --- middleware/auto/auto.go | 7 ++----- middleware/backend_lookup.go | 2 +- middleware/cache/handler.go | 2 +- middleware/chaos/chaos.go | 2 +- middleware/dnssec/handler.go | 4 ++-- middleware/errors/errors.go | 2 +- middleware/etcd/debug_test.go | 12 ++---------- middleware/etcd/handler.go | 9 +++------ middleware/etcd/setup_test.go | 7 +------ middleware/file/file.go | 7 ++----- middleware/file/xfr.go | 6 +++++- middleware/kubernetes/handler.go | 9 +++------ middleware/loadbalance/handler.go | 2 +- middleware/log/log.go | 12 ++++++------ middleware/metrics/handler.go | 2 +- middleware/middleware.go | 11 +++++++++++ middleware/pkg/debug/debug.go | 1 + middleware/pkg/dnsutil/host.go | 2 +- middleware/proxy/proxy.go | 2 +- middleware/rewrite/rewrite.go | 6 +++--- middleware/whoami/setup.go | 2 +- middleware/whoami/whoami.go | 5 +---- middleware/whoami/whoami_test.go | 4 ---- 23 files changed, 51 insertions(+), 67 deletions(-) diff --git a/middleware/auto/auto.go b/middleware/auto/auto.go index 115e86dea..0557fc3d1 100644 --- a/middleware/auto/auto.go +++ b/middleware/auto/auto.go @@ -44,7 +44,7 @@ type ( func (a Auto) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { state := request.Request{W: w, Req: r} if state.QClass() != dns.ClassINET { - return dns.RcodeServerFailure, errors.New("can only deal with ClassINET") + return dns.RcodeServerFailure, middleware.Error(a.Name(), errors.New("can only deal with ClassINET")) } qname := state.Name() @@ -53,10 +53,7 @@ func (a Auto) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (i // Precheck with the origins, i.e. are we allowed to looks here. zone := middleware.Zones(a.Zones.Origins()).Matches(qname) if zone == "" { - if a.Next != nil { - return a.Next.ServeDNS(ctx, w, r) - } - return dns.RcodeServerFailure, errors.New("no next middleware found") + return middleware.NextOrFailure(a.Name(), a.Next, ctx, w, r) } // Now the real zone. diff --git a/middleware/backend_lookup.go b/middleware/backend_lookup.go index c52bb2881..b9de97d67 100644 --- a/middleware/backend_lookup.go +++ b/middleware/backend_lookup.go @@ -410,7 +410,7 @@ func BackendError(b ServiceBackend, zone string, rcode int, state request.Reques state.SizeAndDo(m) state.W.WriteMsg(m) // Return success as the rcode to signal we have written to the client. - return dns.RcodeSuccess, nil + return dns.RcodeSuccess, err } // ServicesToTxt puts debug in TXT RRs. diff --git a/middleware/cache/handler.go b/middleware/cache/handler.go index 77a0cea48..fa2d60ca7 100644 --- a/middleware/cache/handler.go +++ b/middleware/cache/handler.go @@ -34,7 +34,7 @@ func (c *Cache) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) } crr := &ResponseWriter{w, c} - return c.Next.ServeDNS(ctx, crr, r) + return middleware.NextOrFailure(c.Name(), c.Next, ctx, crr, r) } // Name implements the Handler interface. diff --git a/middleware/chaos/chaos.go b/middleware/chaos/chaos.go index 730ef21b5..cad9e0445 100644 --- a/middleware/chaos/chaos.go +++ b/middleware/chaos/chaos.go @@ -23,7 +23,7 @@ type Chaos struct { func (c Chaos) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { state := request.Request{W: w, Req: r} if state.QClass() != dns.ClassCHAOS || state.QType() != dns.TypeTXT { - return c.Next.ServeDNS(ctx, w, r) + return middleware.NextOrFailure(c.Name(), c.Next, ctx, w, r) } m := new(dns.Msg) diff --git a/middleware/dnssec/handler.go b/middleware/dnssec/handler.go index 4c1621c8b..9a45b88b2 100644 --- a/middleware/dnssec/handler.go +++ b/middleware/dnssec/handler.go @@ -18,7 +18,7 @@ func (d Dnssec) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) qtype := state.QType() zone := middleware.Zones(d.zones).Matches(qname) if zone == "" { - return d.Next.ServeDNS(ctx, w, r) + return middleware.NextOrFailure(d.Name(), d.Next, ctx, w, r) } // Intercept queries for DNSKEY, but only if one of the zones matches the qname, otherwise we let @@ -36,7 +36,7 @@ func (d Dnssec) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) } drr := &ResponseWriter{w, d} - return d.Next.ServeDNS(ctx, drr, r) + return middleware.NextOrFailure(d.Name(), d.Next, ctx, drr, r) } var ( diff --git a/middleware/errors/errors.go b/middleware/errors/errors.go index aca05b54a..deba2075d 100644 --- a/middleware/errors/errors.go +++ b/middleware/errors/errors.go @@ -27,7 +27,7 @@ type errorHandler struct { func (h errorHandler) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { defer h.recovery(ctx, w, r) - rcode, err := h.Next.ServeDNS(ctx, w, r) + rcode, err := middleware.NextOrFailure(h.Name(), h.Next, ctx, w, r) if err != nil { state := request.Request{W: w, Req: r} diff --git a/middleware/etcd/debug_test.go b/middleware/etcd/debug_test.go index aa26dd846..7ea91ecbe 100644 --- a/middleware/etcd/debug_test.go +++ b/middleware/etcd/debug_test.go @@ -26,11 +26,7 @@ func TestDebugLookup(t *testing.T) { m := tc.Msg() rec := dnsrecorder.New(&test.ResponseWriter{}) - _, err := etc.ServeDNS(ctxt, rec, m) - if err != nil { - t.Errorf("expected no error, got %v\n", err) - continue - } + etc.ServeDNS(ctxt, rec, m) resp := rec.Msg sort.Sort(test.RRSet(resp.Answer)) @@ -64,11 +60,7 @@ func TestDebugLookupFalse(t *testing.T) { m := tc.Msg() rec := dnsrecorder.New(&test.ResponseWriter{}) - _, err := etc.ServeDNS(ctxt, rec, m) - if err != nil { - t.Errorf("expected no error, got %v\n", err) - continue - } + etc.ServeDNS(ctxt, rec, m) resp := rec.Msg sort.Sort(test.RRSet(resp.Answer)) diff --git a/middleware/etcd/handler.go b/middleware/etcd/handler.go index b462d7c5b..331e57793 100644 --- a/middleware/etcd/handler.go +++ b/middleware/etcd/handler.go @@ -1,7 +1,7 @@ package etcd import ( - "fmt" + "errors" "github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware/etcd/msg" @@ -18,7 +18,7 @@ func (e *Etcd) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) ( opt := middleware.Options{} state := request.Request{W: w, Req: r} if state.QClass() != dns.ClassINET { - return dns.RcodeServerFailure, fmt.Errorf("can only deal with ClassINET") + return dns.RcodeServerFailure, middleware.Error(e.Name(), errors.New("can only deal with ClassINET")) } name := state.Name() if e.Debugging { @@ -43,13 +43,10 @@ func (e *Etcd) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) ( zone := middleware.Zones(e.Zones).Matches(state.Name()) if zone == "" { - if e.Next == nil { - return dns.RcodeServerFailure, nil - } if opt.Debug != "" { r.Question[0].Name = opt.Debug } - return e.Next.ServeDNS(ctx, w, r) + return middleware.NextOrFailure(e.Name(), e.Next, ctx, w, r) } var ( diff --git a/middleware/etcd/setup_test.go b/middleware/etcd/setup_test.go index c1e33109c..e7e51f065 100644 --- a/middleware/etcd/setup_test.go +++ b/middleware/etcd/setup_test.go @@ -17,7 +17,6 @@ import ( etcdc "github.com/coreos/etcd/client" "github.com/mholt/caddy" - "github.com/miekg/dns" "golang.org/x/net/context" ) @@ -66,11 +65,7 @@ func TestLookup(t *testing.T) { m := tc.Msg() rec := dnsrecorder.New(&test.ResponseWriter{}) - _, err := etc.ServeDNS(ctxt, rec, m) - if err != nil { - t.Errorf("expected no error, got: %v for %s %s\n", err, m.Question[0].Name, dns.Type(m.Question[0].Qtype)) - return - } + etc.ServeDNS(ctxt, rec, m) resp := rec.Msg sort.Sort(test.RRSet(resp.Answer)) diff --git a/middleware/file/file.go b/middleware/file/file.go index 6a171740f..13ef894d3 100644 --- a/middleware/file/file.go +++ b/middleware/file/file.go @@ -32,16 +32,13 @@ func (f File) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (i state := request.Request{W: w, Req: r} if state.QClass() != dns.ClassINET { - return dns.RcodeServerFailure, errors.New("can only deal with ClassINET") + return dns.RcodeServerFailure, middleware.Error(f.Name(), errors.New("can only deal with ClassINET")) } qname := state.Name() // TODO(miek): match the qname better in the map zone := middleware.Zones(f.Zones.Names).Matches(qname) if zone == "" { - if f.Next != nil { - return f.Next.ServeDNS(ctx, w, r) - } - return dns.RcodeServerFailure, errors.New("no next middleware found") + return middleware.NextOrFailure(f.Name(), f.Next, ctx, w, r) } z, ok := f.Zones.Z[zone] diff --git a/middleware/file/xfr.go b/middleware/file/xfr.go index e4fcd7efa..3cb21aa11 100644 --- a/middleware/file/xfr.go +++ b/middleware/file/xfr.go @@ -4,6 +4,7 @@ import ( "fmt" "log" + "github.com/miekg/coredns/middleware" "github.com/miekg/coredns/request" "github.com/miekg/dns" @@ -22,7 +23,7 @@ func (x Xfr) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (in return dns.RcodeServerFailure, nil } if state.QType() != dns.TypeAXFR && state.QType() != dns.TypeIXFR { - return 0, fmt.Errorf("xfr called with non transfer type: %d", state.QType()) + return 0, middleware.Error(x.Name(), fmt.Errorf("xfr called with non transfer type: %d", state.QType())) } records := x.All() @@ -55,4 +56,7 @@ func (x Xfr) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (in return dns.RcodeSuccess, nil } +// Name implements the middleware.Hander interface. +func (x Xfr) Name() string { return "xfr" } // Or should we return "file" here? + const transferLength = 1000 // Start a new envelop after message reaches this size in bytes. Intentionally small to test multi envelope parsing. diff --git a/middleware/kubernetes/handler.go b/middleware/kubernetes/handler.go index e21ea6d58..f35792881 100644 --- a/middleware/kubernetes/handler.go +++ b/middleware/kubernetes/handler.go @@ -1,7 +1,7 @@ package kubernetes import ( - "fmt" + "errors" "github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware/pkg/dnsutil" @@ -15,7 +15,7 @@ import ( func (k Kubernetes) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { state := request.Request{W: w, Req: r} if state.QClass() != dns.ClassINET { - return dns.RcodeServerFailure, fmt.Errorf("can only deal with ClassINET") + return dns.RcodeServerFailure, middleware.Error(k.Name(), errors.New("can only deal with ClassINET")) } m := new(dns.Msg) @@ -26,10 +26,7 @@ func (k Kubernetes) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.M // otherwise delegate to the next in the pipeline. zone := middleware.Zones(k.Zones).Matches(state.Name()) if zone == "" { - if k.Next == nil { - return dns.RcodeServerFailure, nil - } - return k.Next.ServeDNS(ctx, w, r) + return middleware.NextOrFailure(k.Name(), k.Next, ctx, w, r) } var ( diff --git a/middleware/loadbalance/handler.go b/middleware/loadbalance/handler.go index 9f9cfb766..9b4baf2ed 100644 --- a/middleware/loadbalance/handler.go +++ b/middleware/loadbalance/handler.go @@ -16,7 +16,7 @@ type RoundRobin struct { // ServeDNS implements the middleware.Handler interface. func (rr RoundRobin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { wrr := &RoundRobinResponseWriter{w} - return rr.Next.ServeDNS(ctx, wrr, r) + return middleware.NextOrFailure(rr.Name(), rr.Next, ctx, wrr, r) } // Name implements the Handler interface. diff --git a/middleware/log/log.go b/middleware/log/log.go index ef9e3f51d..1ba594d0f 100644 --- a/middleware/log/log.go +++ b/middleware/log/log.go @@ -32,14 +32,14 @@ func (l Logger) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) continue } - responseRecorder := dnsrecorder.New(w) - rc, err := l.Next.ServeDNS(ctx, responseRecorder, r) + rrw := dnsrecorder.New(w) + rc, err := middleware.NextOrFailure(l.Name(), l.Next, ctx, rrw, r) if rc > 0 { // There was an error up the chain, but no response has been written yet. // The error must be handled here so the log entry will record the response size. if l.ErrorFunc != nil { - l.ErrorFunc(responseRecorder, r, rc) + l.ErrorFunc(rrw, r, rc) } else { answer := new(dns.Msg) answer.SetRcode(r, rc) @@ -52,16 +52,16 @@ func (l Logger) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) rc = 0 } - class, _ := response.Classify(responseRecorder.Msg) + class, _ := response.Classify(rrw.Msg) if rule.Class == response.All || rule.Class == class { - rep := replacer.New(r, responseRecorder, CommonLogEmptyValue) + rep := replacer.New(r, rrw, CommonLogEmptyValue) rule.Log.Println(rep.Replace(rule.Format)) } return rc, err } - return l.Next.ServeDNS(ctx, w, r) + return middleware.NextOrFailure(l.Name(), l.Next, ctx, w, r) } // Name implements the Handler interface. diff --git a/middleware/metrics/handler.go b/middleware/metrics/handler.go index 11998165e..8b8f4d419 100644 --- a/middleware/metrics/handler.go +++ b/middleware/metrics/handler.go @@ -23,7 +23,7 @@ func (m *Metrics) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg // Record response to get status code and size of the reply. rw := dnsrecorder.New(w) - status, err := m.Next.ServeDNS(ctx, rw, r) + status, err := middleware.NextOrFailure(m.Name(), m.Next, ctx, rw, r) vars.Report(state, zone, rcode.ToString(rw.Rcode), rw.Len, rw.Start) diff --git a/middleware/middleware.go b/middleware/middleware.go index 47d732f4d..da0107cd3 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -2,6 +2,7 @@ package middleware import ( + "errors" "fmt" "github.com/miekg/dns" @@ -65,5 +66,15 @@ func (f HandlerFunc) Name() string { return "handlerfunc" } // Error returns err with 'middleware/name: ' prefixed to it. func Error(name string, err error) error { return fmt.Errorf("%s/%s: %s", "middleware", name, err) } +// NextOrFailure calls next.ServeDNS when next is not nill, otherwise it will return, a ServerFailure +// and a nil error. +func NextOrFailure(name string, next Handler, ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + if next != nil { + return next.ServeDNS(ctx, w, r) + } + + return dns.RcodeServerFailure, Error(name, errors.New("no next middleware found")) +} + // Namespace is the namespace used for the metrics. const Namespace = "coredns" diff --git a/middleware/pkg/debug/debug.go b/middleware/pkg/debug/debug.go index b3c33b344..186872f12 100644 --- a/middleware/pkg/debug/debug.go +++ b/middleware/pkg/debug/debug.go @@ -2,6 +2,7 @@ package debug import "strings" +// Name is the domain prefix we check for when it is a debug query. const Name = "o-o.debug." // IsDebug checks if name is a debugging name, i.e. starts with o-o.debug. diff --git a/middleware/pkg/dnsutil/host.go b/middleware/pkg/dnsutil/host.go index e38eb9e08..aaab586e8 100644 --- a/middleware/pkg/dnsutil/host.go +++ b/middleware/pkg/dnsutil/host.go @@ -8,7 +8,7 @@ import ( "github.com/miekg/dns" ) -// PorseHostPortOrFile parses the strings in s, each string can either be a address, +// ParseHostPortOrFile parses the strings in s, each string can either be a address, // address:port or a filename. The address part is checked and the filename case a // resolv.conf like file is parsed and the nameserver found are returned. func ParseHostPortOrFile(s ...string) ([]string, error) { diff --git a/middleware/proxy/proxy.go b/middleware/proxy/proxy.go index 666430c9f..353e82c19 100644 --- a/middleware/proxy/proxy.go +++ b/middleware/proxy/proxy.go @@ -108,7 +108,7 @@ func (p Proxy) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) ( return dns.RcodeServerFailure, errUnreachable } - return p.Next.ServeDNS(ctx, w, r) + return middleware.NextOrFailure(p.Name(), p.Next, ctx, w, r) } // Name implements the Handler interface. diff --git a/middleware/rewrite/rewrite.go b/middleware/rewrite/rewrite.go index ad28287ac..adbdbca15 100644 --- a/middleware/rewrite/rewrite.go +++ b/middleware/rewrite/rewrite.go @@ -37,9 +37,9 @@ func (rw Rewrite) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg switch result := rule.Rewrite(r); result { case RewriteDone: if rw.noRevert { - return rw.Next.ServeDNS(ctx, w, r) + return middleware.NextOrFailure(rw.Name(), rw.Next, ctx, w, r) } - return rw.Next.ServeDNS(ctx, wr, r) + return middleware.NextOrFailure(rw.Name(), rw.Next, ctx, wr, r) case RewriteIgnored: break case RewriteStatus: @@ -49,7 +49,7 @@ func (rw Rewrite) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg // } } } - return rw.Next.ServeDNS(ctx, w, r) + return middleware.NextOrFailure(rw.Name(), rw.Next, ctx, w, r) } // Name implements the Handler interface. diff --git a/middleware/whoami/setup.go b/middleware/whoami/setup.go index 285060bbe..90e5dd4ae 100644 --- a/middleware/whoami/setup.go +++ b/middleware/whoami/setup.go @@ -21,7 +21,7 @@ func setupWhoami(c *caddy.Controller) error { } dnsserver.GetConfig(c).AddMiddleware(func(next middleware.Handler) middleware.Handler { - return Whoami{Next: next} + return Whoami{} }) return nil diff --git a/middleware/whoami/whoami.go b/middleware/whoami/whoami.go index c86d462c8..01af7107f 100644 --- a/middleware/whoami/whoami.go +++ b/middleware/whoami/whoami.go @@ -6,7 +6,6 @@ import ( "net" "strconv" - "github.com/miekg/coredns/middleware" "github.com/miekg/coredns/request" "github.com/miekg/dns" @@ -15,9 +14,7 @@ import ( // Whoami is a middleware that returns your IP address, port and the protocol used for connecting // to CoreDNS. -type Whoami struct { - Next middleware.Handler -} +type Whoami struct{} // ServeDNS implements the middleware.Handler interface. func (wh Whoami) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { diff --git a/middleware/whoami/whoami_test.go b/middleware/whoami/whoami_test.go index f5376bde1..736cdf18a 100644 --- a/middleware/whoami/whoami_test.go +++ b/middleware/whoami/whoami_test.go @@ -3,7 +3,6 @@ package whoami import ( "testing" - "github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware/pkg/dnsrecorder" "github.com/miekg/coredns/middleware/test" @@ -15,7 +14,6 @@ func TestWhoami(t *testing.T) { wh := Whoami{} tests := []struct { - next middleware.Handler qname string qtype uint16 expectedCode int @@ -23,7 +21,6 @@ func TestWhoami(t *testing.T) { expectedErr error }{ { - next: test.NextHandler(dns.RcodeSuccess, nil), qname: "example.org", qtype: dns.TypeA, expectedCode: dns.RcodeSuccess, @@ -35,7 +32,6 @@ func TestWhoami(t *testing.T) { ctx := context.TODO() for i, tc := range tests { - wh.Next = tc.next req := new(dns.Msg) req.SetQuestion(dns.Fqdn(tc.qname), tc.qtype)