Add middleware.NextOrFailure (#462)
This checks if the next middleware to be called is nil, and if so returns ServerFailure and an error. This makes the next calling more robust and saves some lines of code. Also prefix the error with the name of the middleware to aid in debugging.
This commit is contained in:
parent
451a0bd529
commit
c4ab98c6e3
23 changed files with 51 additions and 67 deletions
|
@ -44,7 +44,7 @@ type (
|
|||
func (a Auto) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||
state := request.Request{W: w, Req: r}
|
||||
if state.QClass() != dns.ClassINET {
|
||||
return dns.RcodeServerFailure, errors.New("can only deal with ClassINET")
|
||||
return dns.RcodeServerFailure, middleware.Error(a.Name(), errors.New("can only deal with ClassINET"))
|
||||
}
|
||||
qname := state.Name()
|
||||
|
||||
|
@ -53,10 +53,7 @@ func (a Auto) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (i
|
|||
// Precheck with the origins, i.e. are we allowed to looks here.
|
||||
zone := middleware.Zones(a.Zones.Origins()).Matches(qname)
|
||||
if zone == "" {
|
||||
if a.Next != nil {
|
||||
return a.Next.ServeDNS(ctx, w, r)
|
||||
}
|
||||
return dns.RcodeServerFailure, errors.New("no next middleware found")
|
||||
return middleware.NextOrFailure(a.Name(), a.Next, ctx, w, r)
|
||||
}
|
||||
|
||||
// Now the real zone.
|
||||
|
|
|
@ -410,7 +410,7 @@ func BackendError(b ServiceBackend, zone string, rcode int, state request.Reques
|
|||
state.SizeAndDo(m)
|
||||
state.W.WriteMsg(m)
|
||||
// Return success as the rcode to signal we have written to the client.
|
||||
return dns.RcodeSuccess, nil
|
||||
return dns.RcodeSuccess, err
|
||||
}
|
||||
|
||||
// ServicesToTxt puts debug in TXT RRs.
|
||||
|
|
2
middleware/cache/handler.go
vendored
2
middleware/cache/handler.go
vendored
|
@ -34,7 +34,7 @@ func (c *Cache) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg)
|
|||
}
|
||||
|
||||
crr := &ResponseWriter{w, c}
|
||||
return c.Next.ServeDNS(ctx, crr, r)
|
||||
return middleware.NextOrFailure(c.Name(), c.Next, ctx, crr, r)
|
||||
}
|
||||
|
||||
// Name implements the Handler interface.
|
||||
|
|
|
@ -23,7 +23,7 @@ type Chaos struct {
|
|||
func (c Chaos) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||
state := request.Request{W: w, Req: r}
|
||||
if state.QClass() != dns.ClassCHAOS || state.QType() != dns.TypeTXT {
|
||||
return c.Next.ServeDNS(ctx, w, r)
|
||||
return middleware.NextOrFailure(c.Name(), c.Next, ctx, w, r)
|
||||
}
|
||||
|
||||
m := new(dns.Msg)
|
||||
|
|
|
@ -18,7 +18,7 @@ func (d Dnssec) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg)
|
|||
qtype := state.QType()
|
||||
zone := middleware.Zones(d.zones).Matches(qname)
|
||||
if zone == "" {
|
||||
return d.Next.ServeDNS(ctx, w, r)
|
||||
return middleware.NextOrFailure(d.Name(), d.Next, ctx, w, r)
|
||||
}
|
||||
|
||||
// Intercept queries for DNSKEY, but only if one of the zones matches the qname, otherwise we let
|
||||
|
@ -36,7 +36,7 @@ func (d Dnssec) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg)
|
|||
}
|
||||
|
||||
drr := &ResponseWriter{w, d}
|
||||
return d.Next.ServeDNS(ctx, drr, r)
|
||||
return middleware.NextOrFailure(d.Name(), d.Next, ctx, drr, r)
|
||||
}
|
||||
|
||||
var (
|
||||
|
|
|
@ -27,7 +27,7 @@ type errorHandler struct {
|
|||
func (h errorHandler) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||
defer h.recovery(ctx, w, r)
|
||||
|
||||
rcode, err := h.Next.ServeDNS(ctx, w, r)
|
||||
rcode, err := middleware.NextOrFailure(h.Name(), h.Next, ctx, w, r)
|
||||
|
||||
if err != nil {
|
||||
state := request.Request{W: w, Req: r}
|
||||
|
|
|
@ -26,11 +26,7 @@ func TestDebugLookup(t *testing.T) {
|
|||
m := tc.Msg()
|
||||
|
||||
rec := dnsrecorder.New(&test.ResponseWriter{})
|
||||
_, err := etc.ServeDNS(ctxt, rec, m)
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, got %v\n", err)
|
||||
continue
|
||||
}
|
||||
etc.ServeDNS(ctxt, rec, m)
|
||||
|
||||
resp := rec.Msg
|
||||
sort.Sort(test.RRSet(resp.Answer))
|
||||
|
@ -64,11 +60,7 @@ func TestDebugLookupFalse(t *testing.T) {
|
|||
m := tc.Msg()
|
||||
|
||||
rec := dnsrecorder.New(&test.ResponseWriter{})
|
||||
_, err := etc.ServeDNS(ctxt, rec, m)
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, got %v\n", err)
|
||||
continue
|
||||
}
|
||||
etc.ServeDNS(ctxt, rec, m)
|
||||
|
||||
resp := rec.Msg
|
||||
sort.Sort(test.RRSet(resp.Answer))
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
package etcd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"errors"
|
||||
|
||||
"github.com/miekg/coredns/middleware"
|
||||
"github.com/miekg/coredns/middleware/etcd/msg"
|
||||
|
@ -18,7 +18,7 @@ func (e *Etcd) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (
|
|||
opt := middleware.Options{}
|
||||
state := request.Request{W: w, Req: r}
|
||||
if state.QClass() != dns.ClassINET {
|
||||
return dns.RcodeServerFailure, fmt.Errorf("can only deal with ClassINET")
|
||||
return dns.RcodeServerFailure, middleware.Error(e.Name(), errors.New("can only deal with ClassINET"))
|
||||
}
|
||||
name := state.Name()
|
||||
if e.Debugging {
|
||||
|
@ -43,13 +43,10 @@ func (e *Etcd) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (
|
|||
|
||||
zone := middleware.Zones(e.Zones).Matches(state.Name())
|
||||
if zone == "" {
|
||||
if e.Next == nil {
|
||||
return dns.RcodeServerFailure, nil
|
||||
}
|
||||
if opt.Debug != "" {
|
||||
r.Question[0].Name = opt.Debug
|
||||
}
|
||||
return e.Next.ServeDNS(ctx, w, r)
|
||||
return middleware.NextOrFailure(e.Name(), e.Next, ctx, w, r)
|
||||
}
|
||||
|
||||
var (
|
||||
|
|
|
@ -17,7 +17,6 @@ import (
|
|||
|
||||
etcdc "github.com/coreos/etcd/client"
|
||||
"github.com/mholt/caddy"
|
||||
"github.com/miekg/dns"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
|
@ -66,11 +65,7 @@ func TestLookup(t *testing.T) {
|
|||
m := tc.Msg()
|
||||
|
||||
rec := dnsrecorder.New(&test.ResponseWriter{})
|
||||
_, err := etc.ServeDNS(ctxt, rec, m)
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, got: %v for %s %s\n", err, m.Question[0].Name, dns.Type(m.Question[0].Qtype))
|
||||
return
|
||||
}
|
||||
etc.ServeDNS(ctxt, rec, m)
|
||||
|
||||
resp := rec.Msg
|
||||
sort.Sort(test.RRSet(resp.Answer))
|
||||
|
|
|
@ -32,16 +32,13 @@ func (f File) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (i
|
|||
state := request.Request{W: w, Req: r}
|
||||
|
||||
if state.QClass() != dns.ClassINET {
|
||||
return dns.RcodeServerFailure, errors.New("can only deal with ClassINET")
|
||||
return dns.RcodeServerFailure, middleware.Error(f.Name(), errors.New("can only deal with ClassINET"))
|
||||
}
|
||||
qname := state.Name()
|
||||
// TODO(miek): match the qname better in the map
|
||||
zone := middleware.Zones(f.Zones.Names).Matches(qname)
|
||||
if zone == "" {
|
||||
if f.Next != nil {
|
||||
return f.Next.ServeDNS(ctx, w, r)
|
||||
}
|
||||
return dns.RcodeServerFailure, errors.New("no next middleware found")
|
||||
return middleware.NextOrFailure(f.Name(), f.Next, ctx, w, r)
|
||||
}
|
||||
|
||||
z, ok := f.Zones.Z[zone]
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/miekg/coredns/middleware"
|
||||
"github.com/miekg/coredns/request"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
@ -22,7 +23,7 @@ func (x Xfr) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (in
|
|||
return dns.RcodeServerFailure, nil
|
||||
}
|
||||
if state.QType() != dns.TypeAXFR && state.QType() != dns.TypeIXFR {
|
||||
return 0, fmt.Errorf("xfr called with non transfer type: %d", state.QType())
|
||||
return 0, middleware.Error(x.Name(), fmt.Errorf("xfr called with non transfer type: %d", state.QType()))
|
||||
}
|
||||
|
||||
records := x.All()
|
||||
|
@ -55,4 +56,7 @@ func (x Xfr) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (in
|
|||
return dns.RcodeSuccess, nil
|
||||
}
|
||||
|
||||
// Name implements the middleware.Hander interface.
|
||||
func (x Xfr) Name() string { return "xfr" } // Or should we return "file" here?
|
||||
|
||||
const transferLength = 1000 // Start a new envelop after message reaches this size in bytes. Intentionally small to test multi envelope parsing.
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
package kubernetes
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"errors"
|
||||
|
||||
"github.com/miekg/coredns/middleware"
|
||||
"github.com/miekg/coredns/middleware/pkg/dnsutil"
|
||||
|
@ -15,7 +15,7 @@ import (
|
|||
func (k Kubernetes) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||
state := request.Request{W: w, Req: r}
|
||||
if state.QClass() != dns.ClassINET {
|
||||
return dns.RcodeServerFailure, fmt.Errorf("can only deal with ClassINET")
|
||||
return dns.RcodeServerFailure, middleware.Error(k.Name(), errors.New("can only deal with ClassINET"))
|
||||
}
|
||||
|
||||
m := new(dns.Msg)
|
||||
|
@ -26,10 +26,7 @@ func (k Kubernetes) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.M
|
|||
// otherwise delegate to the next in the pipeline.
|
||||
zone := middleware.Zones(k.Zones).Matches(state.Name())
|
||||
if zone == "" {
|
||||
if k.Next == nil {
|
||||
return dns.RcodeServerFailure, nil
|
||||
}
|
||||
return k.Next.ServeDNS(ctx, w, r)
|
||||
return middleware.NextOrFailure(k.Name(), k.Next, ctx, w, r)
|
||||
}
|
||||
|
||||
var (
|
||||
|
|
|
@ -16,7 +16,7 @@ type RoundRobin struct {
|
|||
// ServeDNS implements the middleware.Handler interface.
|
||||
func (rr RoundRobin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||
wrr := &RoundRobinResponseWriter{w}
|
||||
return rr.Next.ServeDNS(ctx, wrr, r)
|
||||
return middleware.NextOrFailure(rr.Name(), rr.Next, ctx, wrr, r)
|
||||
}
|
||||
|
||||
// Name implements the Handler interface.
|
||||
|
|
|
@ -32,14 +32,14 @@ func (l Logger) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg)
|
|||
continue
|
||||
}
|
||||
|
||||
responseRecorder := dnsrecorder.New(w)
|
||||
rc, err := l.Next.ServeDNS(ctx, responseRecorder, r)
|
||||
rrw := dnsrecorder.New(w)
|
||||
rc, err := middleware.NextOrFailure(l.Name(), l.Next, ctx, rrw, r)
|
||||
|
||||
if rc > 0 {
|
||||
// There was an error up the chain, but no response has been written yet.
|
||||
// The error must be handled here so the log entry will record the response size.
|
||||
if l.ErrorFunc != nil {
|
||||
l.ErrorFunc(responseRecorder, r, rc)
|
||||
l.ErrorFunc(rrw, r, rc)
|
||||
} else {
|
||||
answer := new(dns.Msg)
|
||||
answer.SetRcode(r, rc)
|
||||
|
@ -52,16 +52,16 @@ func (l Logger) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg)
|
|||
rc = 0
|
||||
}
|
||||
|
||||
class, _ := response.Classify(responseRecorder.Msg)
|
||||
class, _ := response.Classify(rrw.Msg)
|
||||
if rule.Class == response.All || rule.Class == class {
|
||||
rep := replacer.New(r, responseRecorder, CommonLogEmptyValue)
|
||||
rep := replacer.New(r, rrw, CommonLogEmptyValue)
|
||||
rule.Log.Println(rep.Replace(rule.Format))
|
||||
}
|
||||
|
||||
return rc, err
|
||||
|
||||
}
|
||||
return l.Next.ServeDNS(ctx, w, r)
|
||||
return middleware.NextOrFailure(l.Name(), l.Next, ctx, w, r)
|
||||
}
|
||||
|
||||
// Name implements the Handler interface.
|
||||
|
|
|
@ -23,7 +23,7 @@ func (m *Metrics) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg
|
|||
|
||||
// Record response to get status code and size of the reply.
|
||||
rw := dnsrecorder.New(w)
|
||||
status, err := m.Next.ServeDNS(ctx, rw, r)
|
||||
status, err := middleware.NextOrFailure(m.Name(), m.Next, ctx, rw, r)
|
||||
|
||||
vars.Report(state, zone, rcode.ToString(rw.Rcode), rw.Len, rw.Start)
|
||||
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
@ -65,5 +66,15 @@ func (f HandlerFunc) Name() string { return "handlerfunc" }
|
|||
// Error returns err with 'middleware/name: ' prefixed to it.
|
||||
func Error(name string, err error) error { return fmt.Errorf("%s/%s: %s", "middleware", name, err) }
|
||||
|
||||
// NextOrFailure calls next.ServeDNS when next is not nill, otherwise it will return, a ServerFailure
|
||||
// and a nil error.
|
||||
func NextOrFailure(name string, next Handler, ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||
if next != nil {
|
||||
return next.ServeDNS(ctx, w, r)
|
||||
}
|
||||
|
||||
return dns.RcodeServerFailure, Error(name, errors.New("no next middleware found"))
|
||||
}
|
||||
|
||||
// Namespace is the namespace used for the metrics.
|
||||
const Namespace = "coredns"
|
||||
|
|
|
@ -2,6 +2,7 @@ package debug
|
|||
|
||||
import "strings"
|
||||
|
||||
// Name is the domain prefix we check for when it is a debug query.
|
||||
const Name = "o-o.debug."
|
||||
|
||||
// IsDebug checks if name is a debugging name, i.e. starts with o-o.debug.
|
||||
|
|
|
@ -8,7 +8,7 @@ import (
|
|||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// PorseHostPortOrFile parses the strings in s, each string can either be a address,
|
||||
// ParseHostPortOrFile parses the strings in s, each string can either be a address,
|
||||
// address:port or a filename. The address part is checked and the filename case a
|
||||
// resolv.conf like file is parsed and the nameserver found are returned.
|
||||
func ParseHostPortOrFile(s ...string) ([]string, error) {
|
||||
|
|
|
@ -108,7 +108,7 @@ func (p Proxy) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (
|
|||
|
||||
return dns.RcodeServerFailure, errUnreachable
|
||||
}
|
||||
return p.Next.ServeDNS(ctx, w, r)
|
||||
return middleware.NextOrFailure(p.Name(), p.Next, ctx, w, r)
|
||||
}
|
||||
|
||||
// Name implements the Handler interface.
|
||||
|
|
|
@ -37,9 +37,9 @@ func (rw Rewrite) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg
|
|||
switch result := rule.Rewrite(r); result {
|
||||
case RewriteDone:
|
||||
if rw.noRevert {
|
||||
return rw.Next.ServeDNS(ctx, w, r)
|
||||
return middleware.NextOrFailure(rw.Name(), rw.Next, ctx, w, r)
|
||||
}
|
||||
return rw.Next.ServeDNS(ctx, wr, r)
|
||||
return middleware.NextOrFailure(rw.Name(), rw.Next, ctx, wr, r)
|
||||
case RewriteIgnored:
|
||||
break
|
||||
case RewriteStatus:
|
||||
|
@ -49,7 +49,7 @@ func (rw Rewrite) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg
|
|||
// }
|
||||
}
|
||||
}
|
||||
return rw.Next.ServeDNS(ctx, w, r)
|
||||
return middleware.NextOrFailure(rw.Name(), rw.Next, ctx, w, r)
|
||||
}
|
||||
|
||||
// Name implements the Handler interface.
|
||||
|
|
|
@ -21,7 +21,7 @@ func setupWhoami(c *caddy.Controller) error {
|
|||
}
|
||||
|
||||
dnsserver.GetConfig(c).AddMiddleware(func(next middleware.Handler) middleware.Handler {
|
||||
return Whoami{Next: next}
|
||||
return Whoami{}
|
||||
})
|
||||
|
||||
return nil
|
||||
|
|
|
@ -6,7 +6,6 @@ import (
|
|||
"net"
|
||||
"strconv"
|
||||
|
||||
"github.com/miekg/coredns/middleware"
|
||||
"github.com/miekg/coredns/request"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
@ -15,9 +14,7 @@ import (
|
|||
|
||||
// Whoami is a middleware that returns your IP address, port and the protocol used for connecting
|
||||
// to CoreDNS.
|
||||
type Whoami struct {
|
||||
Next middleware.Handler
|
||||
}
|
||||
type Whoami struct{}
|
||||
|
||||
// ServeDNS implements the middleware.Handler interface.
|
||||
func (wh Whoami) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||
|
|
|
@ -3,7 +3,6 @@ package whoami
|
|||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/coredns/middleware"
|
||||
"github.com/miekg/coredns/middleware/pkg/dnsrecorder"
|
||||
"github.com/miekg/coredns/middleware/test"
|
||||
|
||||
|
@ -15,7 +14,6 @@ func TestWhoami(t *testing.T) {
|
|||
wh := Whoami{}
|
||||
|
||||
tests := []struct {
|
||||
next middleware.Handler
|
||||
qname string
|
||||
qtype uint16
|
||||
expectedCode int
|
||||
|
@ -23,7 +21,6 @@ func TestWhoami(t *testing.T) {
|
|||
expectedErr error
|
||||
}{
|
||||
{
|
||||
next: test.NextHandler(dns.RcodeSuccess, nil),
|
||||
qname: "example.org",
|
||||
qtype: dns.TypeA,
|
||||
expectedCode: dns.RcodeSuccess,
|
||||
|
@ -35,7 +32,6 @@ func TestWhoami(t *testing.T) {
|
|||
ctx := context.TODO()
|
||||
|
||||
for i, tc := range tests {
|
||||
wh.Next = tc.next
|
||||
req := new(dns.Msg)
|
||||
req.SetQuestion(dns.Fqdn(tc.qname), tc.qtype)
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue