diff --git a/middleware/kubernetes/handler.go b/middleware/kubernetes/handler.go index f35792881..12277911a 100644 --- a/middleware/kubernetes/handler.go +++ b/middleware/kubernetes/handler.go @@ -26,7 +26,12 @@ func (k Kubernetes) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.M // otherwise delegate to the next in the pipeline. zone := middleware.Zones(k.Zones).Matches(state.Name()) if zone == "" { - return middleware.NextOrFailure(k.Name(), k.Next, ctx, w, r) + // If this is a PTR request, and a the request is in a defined + // pod/service cidr range, process the request in this middleware, + // otherwise pass to next middleware. + if state.Type() != "PTR" || !k.IsRequestInReverseRange(state) { + return middleware.NextOrFailure(k.Name(), k.Next, ctx, w, r) + } } var ( diff --git a/middleware/kubernetes/kubernetes.go b/middleware/kubernetes/kubernetes.go index 7ad2403c6..8c4e08e5d 100644 --- a/middleware/kubernetes/kubernetes.go +++ b/middleware/kubernetes/kubernetes.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "log" + "net" "strings" "time" @@ -41,6 +42,7 @@ type Kubernetes struct { LabelSelector *unversionedapi.LabelSelector Selector *labels.Selector PodMode string + ReverseCidrs []net.IPNet } const ( @@ -128,6 +130,16 @@ func (k *Kubernetes) Reverse(state request.Request, exact bool, opt middleware.O return records, nil, nil } +func (k *Kubernetes) IsRequestInReverseRange(state request.Request) bool { + ip := dnsutil.ExtractAddressFromReverse(state.Name()) + for _, c := range k.ReverseCidrs { + if c.Contains(net.ParseIP(ip)) { + return true + } + } + return false +} + // Lookup implements the ServiceBackend interface. func (k *Kubernetes) Lookup(state request.Request, name string, typ uint16) (*dns.Msg, error) { return k.Proxy.Lookup(state, name, typ) diff --git a/middleware/kubernetes/setup.go b/middleware/kubernetes/setup.go index 7fd9804e0..366084ada 100644 --- a/middleware/kubernetes/setup.go +++ b/middleware/kubernetes/setup.go @@ -3,6 +3,7 @@ package kubernetes import ( "errors" "fmt" + "net" "strings" "time" @@ -84,6 +85,20 @@ func kubernetesParse(c *caddy.Controller) (*Kubernetes, error) { for c.NextBlock() { switch c.Val() { + case "cidrs": + args := c.RemainingArgs() + if len(args) > 0 { + for _, cidrStr := range args { + _, cidr, err := net.ParseCIDR(cidrStr) + if err != nil { + return nil, errors.New(c.Val() + " contains an invalid cidr: " + cidrStr) + } + k8s.ReverseCidrs = append(k8s.ReverseCidrs, *cidr) + + } + continue + } + return nil, c.ArgErr() case "pods": args := c.RemainingArgs() if len(args) == 1 {