diff --git a/CODEOWNERS b/CODEOWNERS index 0ef38a5a2..d38cc426a 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -30,6 +30,7 @@ go.mod @miekg @chrisohaver @johnbelamaric @yongtang @stp-ip /plugin/etcd/ @miekg @nitisht /plugin/file/ @miekg @yongtang @stp-ip /plugin/forward/ @johnbelamaric @miekg @rdrozhdzh +/plugin/forwardcrd/ @christianang /plugin/geoip/ @miekg @snebel29 /plugin/grpc/ @inigohu @miekg @zouyee /plugin/health/ @fastest963 @miekg @zouyee diff --git a/core/dnsserver/zdirectives.go b/core/dnsserver/zdirectives.go index bca217185..3ba04096b 100644 --- a/core/dnsserver/zdirectives.go +++ b/core/dnsserver/zdirectives.go @@ -53,6 +53,7 @@ var Directives = []string{ "secondary", "etcd", "loop", + "forwardcrd", "forward", "grpc", "erratic", diff --git a/core/plugin/zplugin.go b/core/plugin/zplugin.go index a9167eeaf..8d100bc26 100644 --- a/core/plugin/zplugin.go +++ b/core/plugin/zplugin.go @@ -25,6 +25,7 @@ import ( _ "github.com/coredns/coredns/plugin/etcd" _ "github.com/coredns/coredns/plugin/file" _ "github.com/coredns/coredns/plugin/forward" + _ "github.com/coredns/coredns/plugin/forwardcrd" _ "github.com/coredns/coredns/plugin/geoip" _ "github.com/coredns/coredns/plugin/grpc" _ "github.com/coredns/coredns/plugin/header" diff --git a/plugin.cfg b/plugin.cfg index 628e71412..eacf81f90 100644 --- a/plugin.cfg +++ b/plugin.cfg @@ -62,6 +62,7 @@ auto:auto secondary:secondary etcd:etcd loop:loop +forwardcrd:forwardcrd forward:forward grpc:grpc erratic:erratic diff --git a/plugin/forward/forward.go b/plugin/forward/forward.go index 19a469c72..707ad31e1 100644 --- a/plugin/forward/forward.go +++ b/plugin/forward/forward.go @@ -8,6 +8,7 @@ import ( "context" "crypto/tls" "errors" + "fmt" "sync/atomic" "time" @@ -16,6 +17,8 @@ import ( "github.com/coredns/coredns/plugin/dnstap" "github.com/coredns/coredns/plugin/metadata" clog "github.com/coredns/coredns/plugin/pkg/log" + "github.com/coredns/coredns/plugin/pkg/parse" + "github.com/coredns/coredns/plugin/pkg/transport" "github.com/coredns/coredns/request" "github.com/miekg/dns" @@ -54,12 +57,127 @@ type Forward struct { Next plugin.Handler } +// ForwardConfig represents the configuration of the Forward Plugin. This can +// be used with NewWithConfig to create a new configured instance of the +// Forward Plugin. +type ForwardConfig struct { + From string + To []string + Except []string + MaxFails *uint32 + HealthCheck *time.Duration + HealthCheckNoRec bool + ForceTCP bool + PreferUDP bool + TLSConfig *tls.Config + TLSServerName string + Expire *time.Duration + MaxConcurrent *int64 + Policy string + TapPlugin *dnstap.Dnstap +} + // New returns a new Forward. func New() *Forward { f := &Forward{maxfails: 2, tlsConfig: new(tls.Config), expire: defaultExpire, p: new(random), from: ".", hcInterval: hcInterval, opts: options{forceTCP: false, preferUDP: false, hcRecursionDesired: true}} return f } +// NewWithConfig returns a new Forward configured by the provided +// ForwardConfig. +func NewWithConfig(config ForwardConfig) (*Forward, error) { + f := New() + if config.From != "" { + zones := plugin.Host(config.From).NormalizeExact() + f.from = zones[0] // there can only be one here, won't work with non-octet reverse + + if len(zones) > 1 { + log.Warningf("Unsupported CIDR notation: '%s' expands to multiple zones. Using only '%s'.", config.From, f.from) + } + } + for i := 0; i < len(config.Except); i++ { + f.ignored = append(f.ignored, plugin.Host(config.Except[i]).NormalizeExact()...) + } + if config.MaxFails != nil { + f.maxfails = *config.MaxFails + } + if config.HealthCheck != nil { + if *config.HealthCheck < 0 { + return nil, fmt.Errorf("health_check can't be negative: %s", *config.HealthCheck) + } + f.hcInterval = *config.HealthCheck + } + f.opts.hcRecursionDesired = !config.HealthCheckNoRec + f.opts.forceTCP = config.ForceTCP + f.opts.preferUDP = config.PreferUDP + if config.TLSConfig != nil { + f.tlsConfig = config.TLSConfig + } + f.tlsServerName = config.TLSServerName + if f.tlsServerName != "" { + f.tlsConfig.ServerName = f.tlsServerName + } + if config.Expire != nil { + f.expire = *config.Expire + if *config.Expire < 0 { + return nil, fmt.Errorf("expire can't be negative: %s", *config.Expire) + } + } + if config.MaxConcurrent != nil { + if *config.MaxConcurrent < 0 { + return f, fmt.Errorf("max_concurrent can't be negative: %d", *config.MaxConcurrent) + } + f.ErrLimitExceeded = fmt.Errorf("concurrent queries exceeded maximum %d", *config.MaxConcurrent) + f.maxConcurrent = *config.MaxConcurrent + } + if config.Policy != "" { + switch config.Policy { + case "random": + f.p = &random{} + case "round_robin": + f.p = &roundRobin{} + case "sequential": + f.p = &sequential{} + default: + return f, fmt.Errorf("unknown policy '%s'", config.Policy) + } + } + f.tapPlugin = config.TapPlugin + + toHosts, err := parse.HostPortOrFile(config.To...) + if err != nil { + return f, err + } + + transports := make([]string, len(toHosts)) + allowedTrans := map[string]bool{"dns": true, "tls": true} + for i, host := range toHosts { + trans, h := parse.Transport(host) + + if !allowedTrans[trans] { + return f, fmt.Errorf("'%s' is not supported as a destination protocol in forward: %s", trans, host) + } + p := NewProxy(h, trans) + f.proxies = append(f.proxies, p) + transports[i] = trans + } + + // Initialize ClientSessionCache in tls.Config. This may speed up a TLS handshake + // in upcoming connections to the same TLS server. + f.tlsConfig.ClientSessionCache = tls.NewLRUClientSessionCache(len(f.proxies)) + + for i := range f.proxies { + // Only set this for proxies that need it. + if transports[i] == transport.TLS { + f.proxies[i].SetTLSConfig(f.tlsConfig) + } + f.proxies[i].SetExpire(f.expire) + f.proxies[i].health.SetRecursionDesired(f.opts.hcRecursionDesired) + + } + return f, nil +} + // SetProxy appends p to the proxy list and starts healthchecking. func (f *Forward) SetProxy(p *Proxy) { f.proxies = append(f.proxies, p) diff --git a/plugin/forward/forward_test.go b/plugin/forward/forward_test.go index b0ef47ba9..1159e0a85 100644 --- a/plugin/forward/forward_test.go +++ b/plugin/forward/forward_test.go @@ -1,7 +1,13 @@ package forward import ( + "crypto/tls" + "fmt" + "reflect" "testing" + "time" + + "github.com/coredns/coredns/plugin/dnstap" ) func TestList(t *testing.T) { @@ -22,3 +28,281 @@ func TestList(t *testing.T) { } } } + +func TestNewWithConfig(t *testing.T) { + expectedExcept := []string{"foo.com.", "example.com."} + expectedMaxFails := uint32(5) + expectedHealthCheck := 5 * time.Second + expectedServerName := "test" + expectedExpire := 20 * time.Second + expectedMaxConcurrent := int64(5) + expectedDnstap := dnstap.Dnstap{} + + f, err := NewWithConfig(ForwardConfig{ + From: "test", + To: []string{"1.2.3.4:3053", "tls://4.5.6.7"}, + Except: []string{"FOO.com", "example.com"}, + MaxFails: &expectedMaxFails, + HealthCheck: &expectedHealthCheck, + HealthCheckNoRec: true, + ForceTCP: true, + PreferUDP: true, + TLSConfig: &tls.Config{NextProtos: []string{"some-proto"}}, + TLSServerName: expectedServerName, + Expire: &expectedExpire, + MaxConcurrent: &expectedMaxConcurrent, + TapPlugin: &expectedDnstap, + }) + if err != nil { + t.Fatalf("Expected not to error: %s", err) + } + + if f.from != "test." { + t.Fatalf("Expected from to be %s, got: %s", "test.", f.from) + } + + if len(f.proxies) != 2 { + t.Fatalf("Expected proxies to have len of %d, got: %d", 2, len(f.proxies)) + } + + if f.proxies[0].addr != "1.2.3.4:3053" { + t.Fatalf("Expected proxy to have addr of %s, got: %s", "1.2.3.4:3053", f.proxies[0].addr) + } + + if f.proxies[1].addr != "4.5.6.7:853" { + t.Fatalf("Expected proxy to have addr of %s, got: %s", "4.5.6.7:853", f.proxies[1].addr) + } + + if !reflect.DeepEqual(f.ignored, expectedExcept) { + t.Fatalf("Expected ignored to consist of %#v, got: %#v", expectedExcept, f.ignored) + } + + if f.maxfails != 5 { + t.Fatalf("Expected maxfails to be %d, got: %d", expectedMaxFails, f.maxfails) + } + + if f.hcInterval != 5*time.Second { + t.Fatalf("Expected hcInterval to be %s, got: %s", expectedHealthCheck, f.hcInterval) + } + + if f.opts.hcRecursionDesired { + t.Fatalf("Expected hcRecursionDesired to be false") + } + + if !f.opts.forceTCP { + t.Fatalf("Expected forceTCP to be true") + } + + if !f.opts.preferUDP { + t.Fatalf("Expected preferUDP to be true") + } + + if len(f.tlsConfig.NextProtos) != 1 || f.tlsConfig.NextProtos[0] != "some-proto" { + t.Fatalf("Expected tlsConfig to have NextProtos to consist of %s, got: %s", "some-proto", f.tlsConfig.NextProtos) + } + + if f.tlsConfig.ServerName != expectedServerName { + t.Fatalf("Expected tlsConfig to have ServerName to be %s, got: %s", expectedServerName, f.tlsConfig.ServerName) + } + + if f.tlsServerName != "test" { + t.Fatalf("Expected tlsSeverName to be %s, got: %s", expectedServerName, f.tlsServerName) + } + + if f.expire != 20*time.Second { + t.Fatalf("Expected expire to be %s, got: %s", expectedExpire, f.expire) + } + + if f.ErrLimitExceeded == nil || f.ErrLimitExceeded.Error() != "concurrent queries exceeded maximum 5" { + t.Fatalf("Expected ErrLimitExceeded to be %s, got: %s", "concurrent queries exceeded maximum 5", f.ErrLimitExceeded) + } + + if f.maxConcurrent != 5 { + t.Fatalf("Expected maxConcurrent to be %d, got: %d", 5, f.maxConcurrent) + } + + if fmt.Sprintf("%T", f.tlsConfig.ClientSessionCache) != "*tls.lruSessionCache" { + t.Fatalf("Expected tlsConfig.ClientSessionCache to be type %s, got: %T", "*tls.lruSessionCache", f.tlsConfig.ClientSessionCache) + } + + if f.proxies[0].transport.expire != f.expire { + t.Fatalf("Expected proxy.transport.expire to be %s, got: %s", f.expire, f.proxies[0].transport.expire) + } + + if f.proxies[1].transport.expire != f.expire { + t.Fatalf("Expected proxy.transport.expire to be %s, got: %s", f.expire, f.proxies[1].transport.expire) + } + + if f.proxies[0].health.GetRecursionDesired() != f.opts.hcRecursionDesired { + t.Fatalf("Expected proxy.health.GetRecursionDesired to be %t, got: %t", f.opts.hcRecursionDesired, f.proxies[0].health.GetRecursionDesired()) + } + + if f.proxies[1].health.GetRecursionDesired() != f.opts.hcRecursionDesired { + t.Fatalf("Expected proxy.health.GetRecursionDesired to be %t, got: %t", f.opts.hcRecursionDesired, f.proxies[1].health.GetRecursionDesired()) + } + + if f.proxies[0].transport.tlsConfig == f.tlsConfig { + t.Fatalf("Expected proxy.transport.tlsConfig to be nil, got: %#v", f.proxies[0].transport.tlsConfig) + } + + if f.proxies[1].transport.tlsConfig != f.tlsConfig { + t.Fatalf("Expected proxy.transport.tlsConfig to be %#v, got: %#v", f.tlsConfig, f.proxies[1].transport.tlsConfig) + } + + if f.tapPlugin != &expectedDnstap { + t.Fatalf("Expcted tapPlugin to be %p, got: %p", &expectedDnstap, f.tapPlugin) + } +} + +func TestNewWithConfigNegativeHealthCheck(t *testing.T) { + healthCheck, _ := time.ParseDuration("-5s") + + _, err := NewWithConfig(ForwardConfig{ + To: []string{"1.2.3.4:3053", "4.5.6.7"}, + HealthCheck: &healthCheck, + }) + if err == nil || err.Error() != "health_check can't be negative: -5s" { + t.Fatalf("Expected error to be %s, got: %s", "health_check can't be negative: -5s", err) + } +} + +func TestNewWithConfigNegativeExpire(t *testing.T) { + expire, _ := time.ParseDuration("-5s") + + _, err := NewWithConfig(ForwardConfig{ + To: []string{"1.2.3.4:3053", "4.5.6.7"}, + Expire: &expire, + }) + if err == nil || err.Error() != "expire can't be negative: -5s" { + t.Fatalf("Expected error to be %s, got: %s", "expire can't be negative: -5s", err) + } +} + +func TestNewWithConfigNegativeMaxConcurrent(t *testing.T) { + maxConcurrent := int64(-5) + + _, err := NewWithConfig(ForwardConfig{ + To: []string{"1.2.3.4:3053", "4.5.6.7"}, + MaxConcurrent: &maxConcurrent, + }) + if err == nil || err.Error() != "max_concurrent can't be negative: -5" { + t.Fatalf("Expected error to be %s, got: %s", "max_concurrent can't be negative: -5", err) + } +} + +func TestNewWithConfigPolicy(t *testing.T) { + config := ForwardConfig{ + To: []string{"1.2.3.4:3053", "4.5.6.7"}, + } + + config.Policy = "random" + f, err := NewWithConfig(config) + if err != nil { + t.Fatalf("Expected not to error: %s", err) + } + + if _, ok := f.p.(*random); !ok { + t.Fatalf("Expect p to be of type %s, got: %T", "random", f.p) + } + + config.Policy = "round_robin" + f, err = NewWithConfig(config) + if err != nil { + t.Fatalf("Expected not to error: %s", err) + } + + if _, ok := f.p.(*roundRobin); !ok { + t.Fatalf("Expect p to be of type %s, got: %T", "roundRobin", f.p) + } + + config.Policy = "sequential" + f, err = NewWithConfig(config) + if err != nil { + t.Fatalf("Expected not to error: %s", err) + } + + if _, ok := f.p.(*sequential); !ok { + t.Fatalf("Expect p to be of type %s, got: %T", "sequential", f.p) + } + + config.Policy = "invalid_policy" + _, err = NewWithConfig(config) + if err == nil { + t.Fatalf("Expected error %s, got: %s", "unknown policy 'invalid_policy'", err) + } +} + +func TestNewWithConfigServerNameDefault(t *testing.T) { + f, err := NewWithConfig(ForwardConfig{ + To: []string{"1.2.3.4"}, + TLSConfig: &tls.Config{ServerName: "some-server-name"}, + }) + if err != nil { + t.Fatalf("Expected not to error: %s", err) + } + + if f.tlsConfig.ServerName != "some-server-name" { + t.Fatalf("Expect tlsConfig.ServerName to be %s, got: %s", "some-server-name", f.tlsConfig.ServerName) + } +} + +func TestNewWithConfigWithDefaults(t *testing.T) { + f, err := NewWithConfig(ForwardConfig{ + To: []string{"1.2.3.4"}, + }) + if err != nil { + t.Fatalf("Expected not to error: %s", err) + } + + if f.from != "." { + t.Fatalf("Expected from to be %s, got: %s", ".", f.from) + } + + if f.ignored != nil { + t.Fatalf("Expected ignored to be nil but was %#v", f.ignored) + } + + if f.maxfails != 2 { + t.Fatalf("Expected maxfails to be %d, got: %d", 2, f.maxfails) + } + + if f.hcInterval != 500*time.Millisecond { + t.Fatalf("Expected hcInterval to be %s, got: %s", 500*time.Millisecond, f.hcInterval) + } + + if !f.opts.hcRecursionDesired { + t.Fatalf("Expected hcRecursionDesired to be true") + } + + if f.opts.forceTCP { + t.Fatalf("Expected forceTCP to be false") + } + + if f.opts.preferUDP { + t.Fatalf("Expected preferUDP to be false") + } + + if f.tlsConfig == nil { + t.Fatalf("Expected tlsConfig to be non nil") + } + + if f.tlsServerName != "" { + t.Fatalf("Expected tlsServerName to be empty") + } + + if f.expire != defaultExpire { + t.Fatalf("Expected expire to be %s, got: %s", defaultExpire, f.expire) + } + + if f.ErrLimitExceeded != nil { + t.Fatalf("Expected ErrLimitExceeded to be nil") + } + + if f.maxConcurrent != 0 { + t.Fatalf("Expected maxConcurrent to be %d, got: %d", 0, f.maxConcurrent) + } + + if _, ok := f.p.(*random); !ok { + t.Fatalf("Expect p to be of type %s, got: %T", "random", f.p) + } +} diff --git a/plugin/forward/setup.go b/plugin/forward/setup.go index 1f88daba5..baf80f12b 100644 --- a/plugin/forward/setup.go +++ b/plugin/forward/setup.go @@ -1,8 +1,6 @@ package forward import ( - "crypto/tls" - "errors" "fmt" "strconv" "time" @@ -11,9 +9,7 @@ import ( "github.com/coredns/coredns/core/dnsserver" "github.com/coredns/coredns/plugin" "github.com/coredns/coredns/plugin/dnstap" - "github.com/coredns/coredns/plugin/pkg/parse" pkgtls "github.com/coredns/coredns/plugin/pkg/tls" - "github.com/coredns/coredns/plugin/pkg/transport" ) func init() { plugin.Register("forward", setup) } @@ -87,90 +83,46 @@ func parseForward(c *caddy.Controller) (*Forward, error) { } func parseStanza(c *caddy.Controller) (*Forward, error) { - f := New() + cfg := ForwardConfig{} - if !c.Args(&f.from) { - return f, c.ArgErr() - } - origFrom := f.from - zones := plugin.Host(f.from).NormalizeExact() - f.from = zones[0] // there can only be one here, won't work with non-octet reverse - - if len(zones) > 1 { - log.Warningf("Unsupported CIDR notation: '%s' expands to multiple zones. Using only '%s'.", origFrom, f.from) + if !c.Args(&cfg.From) { + return nil, c.ArgErr() } - to := c.RemainingArgs() - if len(to) == 0 { - return f, c.ArgErr() - } - - toHosts, err := parse.HostPortOrFile(to...) - if err != nil { - return f, err - } - - transports := make([]string, len(toHosts)) - allowedTrans := map[string]bool{"dns": true, "tls": true} - for i, host := range toHosts { - trans, h := parse.Transport(host) - - if !allowedTrans[trans] { - return f, fmt.Errorf("'%s' is not supported as a destination protocol in forward: %s", trans, host) - } - p := NewProxy(h, trans) - f.proxies = append(f.proxies, p) - transports[i] = trans + cfg.To = c.RemainingArgs() + if len(cfg.To) == 0 { + return nil, c.ArgErr() } for c.NextBlock() { - if err := parseBlock(c, f); err != nil { - return f, err + if err := parseBlock(c, &cfg); err != nil { + return nil, err } } - if f.tlsServerName != "" { - f.tlsConfig.ServerName = f.tlsServerName - } - - // Initialize ClientSessionCache in tls.Config. This may speed up a TLS handshake - // in upcoming connections to the same TLS server. - f.tlsConfig.ClientSessionCache = tls.NewLRUClientSessionCache(len(f.proxies)) - - for i := range f.proxies { - // Only set this for proxies that need it. - if transports[i] == transport.TLS { - f.proxies[i].SetTLSConfig(f.tlsConfig) - } - f.proxies[i].SetExpire(f.expire) - f.proxies[i].health.SetRecursionDesired(f.opts.hcRecursionDesired) - } - - return f, nil + return NewWithConfig(cfg) } -func parseBlock(c *caddy.Controller, f *Forward) error { +func parseBlock(c *caddy.Controller, cfg *ForwardConfig) error { switch c.Val() { case "except": - ignore := c.RemainingArgs() - if len(ignore) == 0 { + cfg.Except = c.RemainingArgs() + if len(cfg.Except) == 0 { return c.ArgErr() } - for i := 0; i < len(ignore); i++ { - f.ignored = append(f.ignored, plugin.Host(ignore[i]).NormalizeExact()...) - } case "max_fails": if !c.NextArg() { return c.ArgErr() } - n, err := strconv.Atoi(c.Val()) + n, err := strconv.ParseInt(c.Val(), 10, 32) if err != nil { return err } if n < 0 { return fmt.Errorf("max_fails can't be negative: %d", n) } - f.maxfails = uint32(n) + maxFails := uint32(n) + cfg.MaxFails = &maxFails case "health_check": if !c.NextArg() { return c.ArgErr() @@ -179,15 +131,12 @@ func parseBlock(c *caddy.Controller, f *Forward) error { if err != nil { return err } - if dur < 0 { - return fmt.Errorf("health_check can't be negative: %d", dur) - } - f.hcInterval = dur + cfg.HealthCheck = &dur for c.NextArg() { switch hcOpts := c.Val(); hcOpts { case "no_rec": - f.opts.hcRecursionDesired = false + cfg.HealthCheckNoRec = true default: return fmt.Errorf("health_check: unknown option %s", hcOpts) } @@ -197,12 +146,12 @@ func parseBlock(c *caddy.Controller, f *Forward) error { if c.NextArg() { return c.ArgErr() } - f.opts.forceTCP = true + cfg.ForceTCP = true case "prefer_udp": if c.NextArg() { return c.ArgErr() } - f.opts.preferUDP = true + cfg.PreferUDP = true case "tls": args := c.RemainingArgs() if len(args) > 3 { @@ -213,12 +162,12 @@ func parseBlock(c *caddy.Controller, f *Forward) error { if err != nil { return err } - f.tlsConfig = tlsConfig + cfg.TLSConfig = tlsConfig case "tls_servername": if !c.NextArg() { return c.ArgErr() } - f.tlsServerName = c.Val() + cfg.TLSServerName = c.Val() case "expire": if !c.NextArg() { return c.ArgErr() @@ -227,24 +176,12 @@ func parseBlock(c *caddy.Controller, f *Forward) error { if err != nil { return err } - if dur < 0 { - return fmt.Errorf("expire can't be negative: %s", dur) - } - f.expire = dur + cfg.Expire = &dur case "policy": if !c.NextArg() { return c.ArgErr() } - switch x := c.Val(); x { - case "random": - f.p = &random{} - case "round_robin": - f.p = &roundRobin{} - case "sequential": - f.p = &sequential{} - default: - return c.Errf("unknown policy '%s'", x) - } + cfg.Policy = c.Val() case "max_concurrent": if !c.NextArg() { return c.ArgErr() @@ -253,11 +190,8 @@ func parseBlock(c *caddy.Controller, f *Forward) error { if err != nil { return err } - if n < 0 { - return fmt.Errorf("max_concurrent can't be negative: %d", n) - } - f.ErrLimitExceeded = errors.New("concurrent queries exceeded maximum " + c.Val()) - f.maxConcurrent = int64(n) + maxConcurrent := int64(n) + cfg.MaxConcurrent = &maxConcurrent default: return c.Errf("unknown property '%s'", c.Val()) diff --git a/plugin/forwardcrd/Makefile b/plugin/forwardcrd/Makefile new file mode 100644 index 000000000..b777e11c6 --- /dev/null +++ b/plugin/forwardcrd/Makefile @@ -0,0 +1,25 @@ +.PHONY: generate +generate: controller-gen + $(CONTROLLER_GEN) \ + object \ + paths="./apis/..." + $(CONTROLLER_GEN) \ + crd \ + paths="./apis/..." \ + output:crd:artifacts:config=manifests/crds/ + +.PHONY: controller-gen +controller-gen: +ifeq (, $(shell which controller-gen)) + @{ \ + set -e ;\ + CONTROLLER_GEN_TMP_DIR=$$(mktemp -d) ;\ + cd $$CONTROLLER_GEN_TMP_DIR ;\ + go mod init tmp ;\ + go get sigs.k8s.io/controller-tools/cmd/controller-gen@v0.4.1 ;\ + rm -rf $$CONTROLLER_GEN_TMP_DIR ;\ + } +CONTROLLER_GEN=$(shell go env GOPATH)/bin/controller-gen +else +CONTROLLER_GEN=$(shell which controller-gen) +endif diff --git a/plugin/forwardcrd/README.md b/plugin/forwardcrd/README.md new file mode 100644 index 000000000..c2073f28e --- /dev/null +++ b/plugin/forwardcrd/README.md @@ -0,0 +1,256 @@ +# forwardcrd + +## Name + +*forwardcrd* - enables proxying DNS messages to upstream resolvers by reading +the `Forward` CRD from a Kubernetes cluster + +## Description + +The *forwardcrd* plugin is used to dynamically configure stub-domains by +reading a `Forward` CRD within a Kubernetes cluster. + +See [Configuring Private DNS Zones and Upstream Nameservers in +Kubernetes](https://kubernetes.io/blog/2017/04/configuring-private-dns-zones-upstream-nameservers-kubernetes/) +for a description of stub-domains within Kubernetes. + +This plugin can only be used once per Server Block. + +## Security + +This plugin gives users of Kubernetes another avenue of modifying the CoreDNS +server other than the `coredns` configmap. Therefore, it is important that you +limit the RBAC and the `namespace` the plugin reads from to reduce the surface +area a malicious actor can use. Ideally, the level of access to create `Forward` +resources is at the same level as the access to the `coredns` configmap. + +## Syntax + +~~~ +forwardcrd [ZONES...] +~~~ + +With only the plugin specified, the *forwardcrd* plugin will default to the +zone specified in the server's block. It will allow any `Forward` resource that +matches or includes the zone as a suffix. If **ZONES** are specified it allows +any zone listed as a suffix. + +``` +forwardcrd [ZONES...] { + endpoint URL + tls CERT KEY CACERT + kubeconfig KUBECONFIG [CONTEXT] + namespace [NAMESPACE] +} +``` + +* `endpoint` specifies the **URL** for a remote k8s API endpoint. If omitted, + it will connect to k8s in-cluster using the cluster service account. +* `tls` **CERT** **KEY** **CACERT** are the TLS cert, key and the CA cert file + names for remote k8s connection. This option is ignored if connecting + in-cluster (i.e. endpoint is not specified). +* `kubeconfig` **KUBECONFIG [CONTEXT]** authenticates the connection to a remote + k8s cluster using a kubeconfig file. **[CONTEXT]** is optional, if not set, + then the current context specified in kubeconfig will be used. It supports + TLS, username and password, or token-based authentication. This option is + ignored if connecting in-cluster (i.e., the endpoint is not specified). +* `namespace` **[NAMESPACE]** only reads `Forward` resources from the namespace + listed. If this option is omitted then it will read from the default + namespace, `kube-system`. If this option is specified without any namespaces + listed it will read from all namespaces. **Note**: It is recommended to limit + the namespace (e.g to `kube-system`) because this can be potentially misused. + It is ideal to keep the level of write access similar to the `coredns` + configmap in the `kube-system` namespace. + +## Ready + +This plugin reports readiness to the ready plugin. This will happen after it has +synced to the Kubernetes API. + +## Ordering + +Forward behavior can be defined in three ways, via a Server Block, via the +*forwardcrd* plugin, and via the *forward* plugin. If more than one of these +methods is employed and a query falls within the zone of more than one, CoreDNS +selects which one to use based on the following precedence: +Corefile Server Block -> *forwardcrd* plugin -> *forward* plugin. + +When `Forward` CRDs and Server Blocks define stub domains that are used, +domains defined in the Corefile take precedence (in the event of zone overlap). +e.g. if the +domain `example.com` is defined in the Corefile as a stub domain, and a +`Forward` CRD record defined for `sub.example.com`, then `sub.example.com` would +get forwarded to the upstream defined in the Corefile, not the `Forward` CRD. + +When using *forwardcrd* and *forward* plugins in the same Server Block, `Forward` CRDs +take precedence over the *forward* plugin defined in the same Server Block. +e.g. if a `Forward` CRD is defined for `.`, then no queries would be +forwarded to the upstream defined in the *forward* plugin of the same Server Block. + +## Metrics + +`Forward` CRD metrics are all labeled in a single zone (the zone of the enclosing +Server Block). + +## Forward CRD + +The `Forward` CRD has the following spec properties: + +* **from** is the base domain to match for the request to be forwarded. +* **to** are the destination endpoints to forward to. The **to** syntax allows + you to specify a protocol, `tls://9.9.9.9` or `dns://` (or no protocol) for + plain DNS. The number of upstreams is limited to 15. + +## Examples + +The following is an example of how you might modify the `coredns` ConfigMap of +your cluster to enable the *forwardcrd* plugin. The following configuration will +watch and read any `Forward` CRD records in the `kube-system` namespace for +*any* zone name. This means you are able to able to create a `Forward` CRD +record that overlaps an existing zone. + +```yaml +--- +apiVersion: v1 +kind: ConfigMap +metadata: + name: coredns + namespace: kube-system +data: + Corefile: | + .:53 { + errors + health { + lameduck 5s + } + ready + kubernetes cluster.local in-addr.arpa ip6.arpa { + pods insecure + fallthrough in-addr.arpa ip6.arpa + ttl 30 + } + forwardcrd + prometheus :9153 + forward . /etc/resolv.conf + cache 30 + loop + reload + loadbalance + } +``` + +When you want to enable the *forwardcrd* plugin, you will need to apply the CRD +as well. + +``` +kubectl apply -f ./manifests/crds/coredns.io_forwards.yaml +``` + +Also note that the `ClusterRole` for CoreDNS must include: +In addition, you will need to modify the `system:coredns` ClusterRole in the +`kube-system` namespace to include the following: + +```yaml +rules: +- apiGroups: + - coredns.io + resources: + - forwards + verbs: + - list + - watch +``` + +This will allow CoreDNS to watch and list `Forward` CRD records from the +Kubernetes API. + +Now you can configure stubdomains by creating `Forward` CRD records in the +`kube-system` namespace. + +For example, if a cluster operator has a [Consul](https://www.consul.io/) domain +server located at 10.150.0.1, and all Consul names have the suffix +.consul.local. To configure this, the cluster administrator creates the +following record: + +```yaml +--- +apiVersion: coredns.io/v1alpha1 +kind: Forward +metadata: + name: consul-local + namespace: kube-system +spec: + from: consul.local + to: + - 10.150.0.1 +``` + +### Additional examples + +Allow `Forward` resources to be created for any zone and only read `Forward` +resources from the `kube-system` namespace: + +~~~ txt +. { + forwardcrd +} +~~~ + +Allow `Forward` resources to be created for the `.local` zone and only read +`Forward` resources from the `kube-system` namespace: + + +~~~ txt +. { + forwardcrd local +} +~~~ + +or: + +~~~ txt +local { + forwardcrd +} +~~~ + +Only read `Forward` resources from the `dns-system` namespace: + +~~~ txt +. { + forwardcrd { + namespace dns-system + } +} +~~~ + +Read `Forward` resources from all namespaces: + +~~~ txt +. { + forwardcrd { + namespace + } +} +~~~ + +Connect to Kubernetes with CoreDNS running outside the cluster: + +~~~ txt +. { + forwardcrd { + endpoint https://k8s-endpoint:8443 + tls cert key cacert + } +} +~~~ + +or: + +~~~ txt +. { + forwardcrd { + kubeconfig ./kubeconfig + } +} +~~~ diff --git a/plugin/forwardcrd/apis/coredns/v1alpha1/forward_types.go b/plugin/forwardcrd/apis/coredns/v1alpha1/forward_types.go new file mode 100644 index 000000000..dabdbe70d --- /dev/null +++ b/plugin/forwardcrd/apis/coredns/v1alpha1/forward_types.go @@ -0,0 +1,39 @@ +package v1alpha1 + +import ( + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +// ForwardSpec represents the spec of a Forward +type ForwardSpec struct { + From string `json:"from,omitempty"` + To []string `json:"to,omitempty"` +} + +// ForwardStatus represents the status of a Forward +type ForwardStatus struct { +} + +// +kubebuilder:object:root=true +// +kubebuilder:printcolumn:name="From",type=string,JSONPath=`.spec.from` +// +kubebuilder:printcolumn:name="To",type=string,JSONPath=`.spec.to` + +// Forward represents a zone that should have its DNS requests forwarded to an +// upstream DNS server within CoreDNS +type Forward struct { + metav1.TypeMeta `json:",inline"` + metav1.ObjectMeta `json:"metadata,omitempty"` + + Spec ForwardSpec `json:"spec,omitempty"` + Status ForwardStatus `json:"status,omitempty"` +} + +// +kubebuilder:object:root=true + +// ForwardList represents a list of Forwards +type ForwardList struct { + metav1.TypeMeta `json:",inline"` + metav1.ListMeta `json:"metadata,omitempty"` + + Items []Forward `json:"items"` +} diff --git a/plugin/forwardcrd/apis/coredns/v1alpha1/groupversion_info.go b/plugin/forwardcrd/apis/coredns/v1alpha1/groupversion_info.go new file mode 100644 index 000000000..2432d0300 --- /dev/null +++ b/plugin/forwardcrd/apis/coredns/v1alpha1/groupversion_info.go @@ -0,0 +1,36 @@ +package v1alpha1 + +import ( + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/runtime/schema" +) + +// +kubebuilder:object:generate=true +// +groupName=coredns.io + +var ( + // GroupVersion is group version used to register these objects + GroupVersion = schema.GroupVersion{Group: "coredns.io", Version: "v1alpha1"} + + // SchemeBuilder is used to add go types to the GroupVersionKind scheme + SchemeBuilder = &runtime.SchemeBuilder{} + + // AddToScheme adds the types in this group-version to the given scheme. + AddToScheme = SchemeBuilder.AddToScheme +) + +func init() { + SchemeBuilder.Register(addKnownTypes) +} + +func addKnownTypes(scheme *runtime.Scheme) error { + scheme.AddKnownTypes(GroupVersion, + &Forward{}, + &ForwardList{}, + ) + + metav1.AddToGroupVersion(scheme, GroupVersion) + + return nil +} diff --git a/plugin/forwardcrd/apis/coredns/v1alpha1/zz_generated.deepcopy.go b/plugin/forwardcrd/apis/coredns/v1alpha1/zz_generated.deepcopy.go new file mode 100644 index 000000000..914148871 --- /dev/null +++ b/plugin/forwardcrd/apis/coredns/v1alpha1/zz_generated.deepcopy.go @@ -0,0 +1,103 @@ +// +build !ignore_autogenerated + +// Code generated by controller-gen. DO NOT EDIT. + +package v1alpha1 + +import ( + "k8s.io/apimachinery/pkg/runtime" +) + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *Forward) DeepCopyInto(out *Forward) { + *out = *in + out.TypeMeta = in.TypeMeta + in.ObjectMeta.DeepCopyInto(&out.ObjectMeta) + in.Spec.DeepCopyInto(&out.Spec) + out.Status = in.Status +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Forward. +func (in *Forward) DeepCopy() *Forward { + if in == nil { + return nil + } + out := new(Forward) + in.DeepCopyInto(out) + return out +} + +// DeepCopyObject is an autogenerated deepcopy function, copying the receiver, creating a new runtime.Object. +func (in *Forward) DeepCopyObject() runtime.Object { + if c := in.DeepCopy(); c != nil { + return c + } + return nil +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *ForwardList) DeepCopyInto(out *ForwardList) { + *out = *in + out.TypeMeta = in.TypeMeta + in.ListMeta.DeepCopyInto(&out.ListMeta) + if in.Items != nil { + in, out := &in.Items, &out.Items + *out = make([]Forward, len(*in)) + for i := range *in { + (*in)[i].DeepCopyInto(&(*out)[i]) + } + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ForwardList. +func (in *ForwardList) DeepCopy() *ForwardList { + if in == nil { + return nil + } + out := new(ForwardList) + in.DeepCopyInto(out) + return out +} + +// DeepCopyObject is an autogenerated deepcopy function, copying the receiver, creating a new runtime.Object. +func (in *ForwardList) DeepCopyObject() runtime.Object { + if c := in.DeepCopy(); c != nil { + return c + } + return nil +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *ForwardSpec) DeepCopyInto(out *ForwardSpec) { + *out = *in + if in.To != nil { + in, out := &in.To, &out.To + *out = make([]string, len(*in)) + copy(*out, *in) + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ForwardSpec. +func (in *ForwardSpec) DeepCopy() *ForwardSpec { + if in == nil { + return nil + } + out := new(ForwardSpec) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *ForwardStatus) DeepCopyInto(out *ForwardStatus) { + *out = *in +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ForwardStatus. +func (in *ForwardStatus) DeepCopy() *ForwardStatus { + if in == nil { + return nil + } + out := new(ForwardStatus) + in.DeepCopyInto(out) + return out +} diff --git a/plugin/forwardcrd/controller.go b/plugin/forwardcrd/controller.go new file mode 100644 index 000000000..b35e77b62 --- /dev/null +++ b/plugin/forwardcrd/controller.go @@ -0,0 +1,238 @@ +package forwardcrd + +import ( + "context" + "errors" + "fmt" + "sync" + "time" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/dnstap" + "github.com/coredns/coredns/plugin/forward" + corednsv1alpha1 "github.com/coredns/coredns/plugin/forwardcrd/apis/coredns/v1alpha1" + + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" + "k8s.io/apimachinery/pkg/runtime" + utilruntime "k8s.io/apimachinery/pkg/util/runtime" + "k8s.io/apimachinery/pkg/util/wait" + "k8s.io/apimachinery/pkg/watch" + "k8s.io/client-go/dynamic" + "k8s.io/client-go/tools/cache" + "k8s.io/client-go/util/workqueue" +) + +const defaultResyncPeriod = 0 + +type forwardCRDController interface { + Run(threads int) + HasSynced() bool + Stop() error +} + +type forwardCRDControl struct { + client dynamic.Interface + scheme *runtime.Scheme + forwardController cache.Controller + forwardLister cache.Store + workqueue workqueue.RateLimitingInterface + pluginMap *PluginInstanceMap + instancer pluginInstancer + tapPlugin *dnstap.Dnstap + namespace string + + // stopLock is used to enforce only a single call to Stop is active. + // Needed because we allow stopping through an http endpoint and + // allowing concurrent stoppers leads to stack traces. + stopLock sync.Mutex + shutdown bool + stopCh chan struct{} +} + +type lifecyclePluginHandler interface { + plugin.Handler + OnStartup() error + OnShutdown() error +} + +type pluginInstancer func(forward.ForwardConfig) (lifecyclePluginHandler, error) + +func newForwardCRDController(ctx context.Context, client dynamic.Interface, scheme *runtime.Scheme, namespace string, pluginMap *PluginInstanceMap, instancer pluginInstancer) forwardCRDController { + controller := forwardCRDControl{ + client: client, + scheme: scheme, + stopCh: make(chan struct{}), + namespace: namespace, + pluginMap: pluginMap, + instancer: instancer, + workqueue: workqueue.NewNamedRateLimitingQueue(workqueue.DefaultControllerRateLimiter(), "ForwardCRD"), + } + + controller.forwardLister, controller.forwardController = cache.NewInformer( + &cache.ListWatch{ + ListFunc: func(options metav1.ListOptions) (runtime.Object, error) { + if namespace != "" { + return controller.client.Resource(corednsv1alpha1.GroupVersion.WithResource("forwards")).Namespace(namespace).List(ctx, options) + } + return controller.client.Resource(corednsv1alpha1.GroupVersion.WithResource("forwards")).List(ctx, options) + }, + WatchFunc: func(options metav1.ListOptions) (watch.Interface, error) { + if namespace != "" { + return controller.client.Resource(corednsv1alpha1.GroupVersion.WithResource("forwards")).Namespace(namespace).Watch(ctx, options) + } + return controller.client.Resource(corednsv1alpha1.GroupVersion.WithResource("forwards")).Watch(ctx, options) + }, + }, + &unstructured.Unstructured{}, + defaultResyncPeriod, + cache.ResourceEventHandlerFuncs{ + AddFunc: func(obj interface{}) { + key, err := cache.MetaNamespaceKeyFunc(obj) + if err == nil { + controller.workqueue.Add(key) + } + }, + UpdateFunc: func(oldObj, newObj interface{}) { + key, err := cache.MetaNamespaceKeyFunc(newObj) + if err == nil { + controller.workqueue.Add(key) + } + }, + DeleteFunc: func(obj interface{}) { + key, err := cache.DeletionHandlingMetaNamespaceKeyFunc(obj) + if err == nil { + controller.workqueue.Add(key) + } + }, + }, + ) + + return &controller +} + +// Run starts the controller. Threads is the number of workers that can process +// work on the workqueue in parallel. +func (d *forwardCRDControl) Run(threads int) { + defer utilruntime.HandleCrash() + defer d.workqueue.ShutDown() + + go d.forwardController.Run(d.stopCh) + + if !cache.WaitForCacheSync(d.stopCh, d.forwardController.HasSynced) { + utilruntime.HandleError(errors.New("Timed out waiting for caches to sync")) + return + } + + for i := 0; i < threads; i++ { + go wait.Until(d.runWorker, time.Second, d.stopCh) + } + + <-d.stopCh + + // Shutdown all plugins + for _, plugin := range d.pluginMap.List() { + plugin.OnShutdown() + } +} + +// HasSynced returns true once the controller has completed an initial resource +// listing. +func (d *forwardCRDControl) HasSynced() bool { + return d.forwardController.HasSynced() +} + +// Stop stops the controller. +func (d *forwardCRDControl) Stop() error { + d.stopLock.Lock() + defer d.stopLock.Unlock() + + // Only try draining the workqueue if we haven't already. + if !d.shutdown { + close(d.stopCh) + d.shutdown = true + + return nil + } + + return fmt.Errorf("shutdown already in progress") +} + +func (d *forwardCRDControl) runWorker() { + for d.processNextItem() { + } +} + +func (d *forwardCRDControl) processNextItem() bool { + key, quit := d.workqueue.Get() + if quit { + return false + } + + defer d.workqueue.Done(key) + + err := d.sync(key.(string)) + if err != nil { + log.Errorf("Error syncing Forward %v: %v", key, err) + d.workqueue.AddRateLimited(key) + return true + } + + d.workqueue.Forget(key) + + return true +} + +func (d *forwardCRDControl) sync(key string) error { + obj, exists, err := d.forwardLister.GetByKey(key) + if err != nil { + return err + } + + if !exists { + plugin := d.pluginMap.Delete(key) + if plugin != nil { + plugin.OnShutdown() + } + } else { + f, err := d.convertToForward(obj.(runtime.Object)) + if err != nil { + return err + } + forwardConfig := forward.ForwardConfig{ + From: f.Spec.From, + To: f.Spec.To, + TapPlugin: d.tapPlugin, + } + plugin, err := d.instancer(forwardConfig) + if err != nil { + return err + } + err = plugin.OnStartup() + if err != nil { + return err + } + oldPlugin, updated := d.pluginMap.Upsert(key, f.Spec.From, plugin) + if updated { + oldPlugin.OnShutdown() + } + } + + return nil +} + +func (d *forwardCRDControl) convertToForward(obj runtime.Object) (*corednsv1alpha1.Forward, error) { + unstructured, ok := obj.(*unstructured.Unstructured) + if !ok { + return nil, fmt.Errorf("object was not Unstructured") + } + + switch unstructured.GetKind() { + case "Forward": + forward := &corednsv1alpha1.Forward{} + err := d.scheme.Convert(unstructured, forward, nil) + return forward, err + default: + return nil, fmt.Errorf("unsupported object type: %T", unstructured) + } +} diff --git a/plugin/forwardcrd/controller_test.go b/plugin/forwardcrd/controller_test.go new file mode 100644 index 000000000..8531eacc2 --- /dev/null +++ b/plugin/forwardcrd/controller_test.go @@ -0,0 +1,333 @@ +package forwardcrd + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/coredns/coredns/plugin/forward" + corednsv1alpha1 "github.com/coredns/coredns/plugin/forwardcrd/apis/coredns/v1alpha1" + + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/apimachinery/pkg/util/wait" + "k8s.io/client-go/dynamic/fake" +) + +func TestCreateForward(t *testing.T) { + controller, client, testPluginInstancer, pluginInstanceMap := setupControllerTestcase(t, "") + forward := &corednsv1alpha1.Forward{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-dns-zone", + Namespace: "default", + }, + Spec: corednsv1alpha1.ForwardSpec{ + From: "crd.test", + To: []string{"127.0.0.2", "127.0.0.3"}, + }, + } + + _, err := client.Resource(corednsv1alpha1.GroupVersion.WithResource("forwards")). + Namespace("default"). + Create(context.Background(), mustForwardToUnstructured(forward), metav1.CreateOptions{}) + if err != nil { + t.Fatalf("Expected not to error: %s", err) + } + + err = wait.Poll(time.Second, time.Second*5, func() (bool, error) { + return testPluginInstancer.NewWithConfigCallCount() == 1, nil + }) + if err != nil { + t.Fatalf("Expected plugin instance to have been called: %s", err) + } + + handler := testPluginInstancer.NewWithConfigArgsForCall(0) + if handler.ReceivedConfig.From != "crd.test" { + t.Fatalf("Expected plugin to be created for zone: %s but was: %s", "crd.test", handler.ReceivedConfig.From) + } + + if len(handler.ReceivedConfig.To) != 2 { + t.Fatalf("Expected plugin to contain exactly 2 servers to forward to but contains: %#v", handler.ReceivedConfig.To) + } + + if handler.ReceivedConfig.To[0] != "127.0.0.2" { + t.Fatalf("Expected plugin to be created to forward to: %s but was: %s", "127.0.0.2", handler.ReceivedConfig.To[0]) + } + + if handler.ReceivedConfig.To[1] != "127.0.0.3" { + t.Fatalf("Expected plugin to be created to forward to: %s but was: %s", "127.0.0.3", handler.ReceivedConfig.To[1]) + } + + pluginHandler, ok := pluginInstanceMap.Get("crd.test") + if !ok { + t.Fatal("Expected plugin lookup to succeed") + } + + if pluginHandler != handler { + t.Fatalf("Exepcted plugin lookup to match what the instancer provided: %#v but was %#v", handler, pluginHandler) + } + + if testPluginInstancer.testPluginHandlers[0].OnStartupCallCount() != 1 { + t.Fatalf("Expected plugin OnStartup to have been called once, but got: %d", testPluginInstancer.testPluginHandlers[0].OnStartupCallCount()) + } + + if err := controller.Stop(); err != nil { + t.Fatalf("Expected no error: %s", err) + } + + err = wait.Poll(time.Second, time.Second*5, func() (bool, error) { + return testPluginInstancer.testPluginHandlers[0].OnShutdownCallCount() == 1, nil + }) + if err != nil { + t.Fatalf("Expected plugin OnShutdown to have been called once, but got: %d", testPluginInstancer.testPluginHandlers[0].OnShutdownCallCount()) + } +} + +func TestUpdateForward(t *testing.T) { + controller, client, testPluginInstancer, pluginInstanceMap := setupControllerTestcase(t, "") + forward := &corednsv1alpha1.Forward{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-dns-zone", + Namespace: "default", + }, + Spec: corednsv1alpha1.ForwardSpec{ + From: "crd.test", + To: []string{"127.0.0.2"}, + }, + } + + unstructuredForward, err := client.Resource(corednsv1alpha1.GroupVersion.WithResource("forwards")). + Namespace("default"). + Create(context.Background(), mustForwardToUnstructured(forward), metav1.CreateOptions{}) + if err != nil { + t.Fatalf("Expected not to error: %s", err) + } + + err = wait.Poll(time.Second, time.Second*5, func() (bool, error) { + return testPluginInstancer.NewWithConfigCallCount() == 1, nil + }) + if err != nil { + t.Fatalf("Expected plugin instance to have been called: %s", err) + } + + forward = mustUnstructuredToForward(unstructuredForward) + forward.Spec.From = "other.test" + forward.Spec.To = []string{"127.0.0.3"} + + _, err = client.Resource(corednsv1alpha1.GroupVersion.WithResource("forwards")). + Namespace("default"). + Update(context.Background(), mustForwardToUnstructured(forward), metav1.UpdateOptions{}) + if err != nil { + t.Fatalf("Expected not to error: %s", err) + } + + err = wait.Poll(time.Second, time.Second*5, func() (bool, error) { + return testPluginInstancer.NewWithConfigCallCount() == 2, nil + }) + if err != nil { + t.Fatalf("Expected plugin instance to have been called: %s", err) + } + + handler := testPluginInstancer.NewWithConfigArgsForCall(1) + if handler.ReceivedConfig.From != "other.test" { + t.Fatalf("Expected plugin to be created for zone: %s but was: %s", "other.test", handler.ReceivedConfig.From) + } + + if len(handler.ReceivedConfig.To) != 1 { + t.Fatalf("Expected plugin to contain exactly 1 server to forward to but contains: %#v", handler.ReceivedConfig.To) + } + + if handler.ReceivedConfig.To[0] != "127.0.0.3" { + t.Fatalf("Expected plugin to be created to forward to: %s but was: %s", "127.0.0.3", handler.ReceivedConfig.To[0]) + } + + pluginHandler, ok := pluginInstanceMap.Get("other.test") + if !ok { + t.Fatal("Expected plugin lookup to succeed") + } + + if pluginHandler != handler { + t.Fatalf("Exepcted plugin lookup to match what the instancer provided: %#v but was %#v", handler, pluginHandler) + } + + _, ok = pluginInstanceMap.Get("crd.test") + if ok { + t.Fatal("Expected lookup for crd.test to fail") + } + + if testPluginInstancer.testPluginHandlers[0].OnShutdownCallCount() != 1 { + t.Fatalf("Expected plugin OnShutdown to have been called once, but got: %d", testPluginInstancer.testPluginHandlers[0].OnShutdownCallCount()) + } + + if err := controller.Stop(); err != nil { + t.Fatalf("Expected no error: %s", err) + } +} + +func TestDeleteForward(t *testing.T) { + controller, client, testPluginInstancer, pluginInstanceMap := setupControllerTestcase(t, "") + forward := &corednsv1alpha1.Forward{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-dns-zone", + Namespace: "default", + }, + Spec: corednsv1alpha1.ForwardSpec{ + From: "crd.test", + To: []string{"127.0.0.2"}, + }, + } + + _, err := client.Resource(corednsv1alpha1.GroupVersion.WithResource("forwards")). + Namespace("default"). + Create(context.Background(), mustForwardToUnstructured(forward), metav1.CreateOptions{}) + if err != nil { + t.Fatalf("Expected not to error: %s", err) + } + + err = wait.Poll(time.Second, time.Second*5, func() (bool, error) { + return testPluginInstancer.NewWithConfigCallCount() == 1, nil + }) + if err != nil { + t.Fatalf("Expected plugin instance to have been called: %s", err) + } + + err = client.Resource(corednsv1alpha1.GroupVersion.WithResource("forwards")). + Namespace("default"). + Delete(context.Background(), "test-dns-zone", metav1.DeleteOptions{}) + if err != nil { + t.Fatalf("Expected not to error: %s", err) + } + + err = wait.Poll(time.Second, time.Second*5, func() (bool, error) { + _, ok := pluginInstanceMap.Get("crd.test") + return !ok, nil + }) + if err != nil { + t.Fatalf("Expected lookup for crd.test to fail: %s", err) + } + + if testPluginInstancer.testPluginHandlers[0].OnShutdownCallCount() != 1 { + t.Fatalf("Expected plugin OnShutdown to have been called once, but got: %d", testPluginInstancer.testPluginHandlers[0].OnShutdownCallCount()) + } + + if err := controller.Stop(); err != nil { + t.Fatalf("Expected no error: %s", err) + } +} + +func TestForwardLimitNamespace(t *testing.T) { + controller, client, testPluginInstancer, pluginInstanceMap := setupControllerTestcase(t, "kube-system") + forward := &corednsv1alpha1.Forward{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-dns-zone", + Namespace: "default", + }, + Spec: corednsv1alpha1.ForwardSpec{ + From: "crd.test", + To: []string{"127.0.0.2"}, + }, + } + + _, err := client.Resource(corednsv1alpha1.GroupVersion.WithResource("forwards")). + Namespace("default"). + Create(context.Background(), mustForwardToUnstructured(forward), metav1.CreateOptions{}) + if err != nil { + t.Fatalf("Expected not to error: %s", err) + } + + kubeSystemForward := &corednsv1alpha1.Forward{ + ObjectMeta: metav1.ObjectMeta{ + Name: "system-dns-zone", + Namespace: "kube-system", + }, + Spec: corednsv1alpha1.ForwardSpec{ + From: "system.test", + To: []string{"127.0.0.3"}, + }, + } + + _, err = client.Resource(corednsv1alpha1.GroupVersion.WithResource("forwards")). + Namespace("kube-system"). + Create(context.Background(), mustForwardToUnstructured(kubeSystemForward), metav1.CreateOptions{}) + if err != nil { + t.Fatalf("Expected not to error: %s", err) + } + + err = wait.Poll(time.Second, time.Second*5, func() (bool, error) { + return testPluginInstancer.NewWithConfigCallCount() == 1, nil + }) + if err != nil { + t.Fatalf("Expected plugin instance to have been called exactly once: %s, plugin instance call count: %d", err, testPluginInstancer.NewWithConfigCallCount()) + } + + handler := testPluginInstancer.NewWithConfigArgsForCall(0) + if handler.ReceivedConfig.From != "system.test" { + t.Fatalf("Expected plugin to be created for zone: %s but was: %s", "system.test", handler.ReceivedConfig.From) + } + + _, ok := pluginInstanceMap.Get("system.test") + if !ok { + t.Fatal("Expected plugin lookup to succeed") + } + + _, ok = pluginInstanceMap.Get("crd.test") + if ok { + t.Fatal("Expected plugin lookup to fail") + } + + if err := controller.Stop(); err != nil { + t.Fatalf("Expected no error: %s", err) + } +} + +func setupControllerTestcase(t *testing.T, namespace string) (forwardCRDController, *fake.FakeDynamicClient, *TestPluginInstancer, *PluginInstanceMap) { + scheme := runtime.NewScheme() + scheme.AddKnownTypes(corednsv1alpha1.GroupVersion, &corednsv1alpha1.Forward{}) + customListKinds := map[schema.GroupVersionResource]string{ + corednsv1alpha1.GroupVersion.WithResource("forwards"): "ForwardList", + } + client := fake.NewSimpleDynamicClientWithCustomListKinds(scheme, customListKinds) + pluginMap := NewPluginInstanceMap() + testPluginInstancer := &TestPluginInstancer{} + controller := newForwardCRDController(context.Background(), client, scheme, namespace, pluginMap, func(cfg forward.ForwardConfig) (lifecyclePluginHandler, error) { + return testPluginInstancer.NewWithConfig(cfg) + }) + + go controller.Run(1) + + err := wait.Poll(time.Second, time.Second*5, func() (bool, error) { + return controller.HasSynced(), nil + }) + if err != nil { + t.Fatalf("Expected controller to have synced: %s", err) + } + + return controller, client, testPluginInstancer, pluginMap +} + +func mustForwardToUnstructured(forward *corednsv1alpha1.Forward) *unstructured.Unstructured { + forward.TypeMeta = metav1.TypeMeta{ + Kind: "Forward", + APIVersion: "coredns.io/v1alpha1", + } + + obj, err := runtime.DefaultUnstructuredConverter.ToUnstructured(forward) + if err != nil { + panic(fmt.Sprintf("coding error: unable to convert to unstructured: %s", err)) + } + return &unstructured.Unstructured{ + Object: obj, + } +} + +func mustUnstructuredToForward(obj *unstructured.Unstructured) *corednsv1alpha1.Forward { + forward := &corednsv1alpha1.Forward{} + err := runtime.DefaultUnstructuredConverter.FromUnstructured(obj.Object, forward) + if err != nil { + panic(fmt.Sprintf("coding error: unable to convert from unstructured: %s", err)) + } + return forward +} diff --git a/plugin/forwardcrd/fakes_test.go b/plugin/forwardcrd/fakes_test.go new file mode 100644 index 000000000..9b6cb0bcf --- /dev/null +++ b/plugin/forwardcrd/fakes_test.go @@ -0,0 +1,86 @@ +package forwardcrd + +import ( + "context" + "sync" + + "github.com/coredns/coredns/plugin/forward" + + "github.com/miekg/dns" +) + +type TestPluginHandler struct { + mutex sync.Mutex + ReceivedConfig forward.ForwardConfig + onStartupCallCount int + onShutdownCallCount int +} + +func (t *TestPluginHandler) ServeDNS(context.Context, dns.ResponseWriter, *dns.Msg) (int, error) { + return 0, nil +} + +func (t *TestPluginHandler) Name() string { return "" } + +func (t *TestPluginHandler) OnStartup() error { + t.mutex.Lock() + defer t.mutex.Unlock() + t.onStartupCallCount++ + return nil +} + +func (t *TestPluginHandler) OnShutdown() error { + t.mutex.Lock() + defer t.mutex.Unlock() + t.onShutdownCallCount++ + return nil +} + +func (t *TestPluginHandler) OnStartupCallCount() int { + t.mutex.Lock() + defer t.mutex.Unlock() + return t.onStartupCallCount +} + +func (t *TestPluginHandler) OnShutdownCallCount() int { + t.mutex.Lock() + defer t.mutex.Unlock() + return t.onShutdownCallCount +} + +type TestPluginInstancer struct { + mutex sync.Mutex + testPluginHandlers []*TestPluginHandler +} + +func (t *TestPluginInstancer) NewWithConfig(config forward.ForwardConfig) (lifecyclePluginHandler, error) { + t.mutex.Lock() + defer t.mutex.Unlock() + + testPluginHandler := &TestPluginHandler{ + ReceivedConfig: config, + } + t.testPluginHandlers = append(t.testPluginHandlers, testPluginHandler) + return testPluginHandler, nil +} + +func (t *TestPluginInstancer) NewWithConfigArgsForCall(index int) *TestPluginHandler { + t.mutex.Lock() + defer t.mutex.Unlock() + + return t.testPluginHandlers[index] +} + +func (t *TestPluginInstancer) NewWithConfigCallCount() int { + t.mutex.Lock() + defer t.mutex.Unlock() + + return len(t.testPluginHandlers) +} + +type TestController struct { +} + +func (t *TestController) Run(threads int) {} +func (t *TestController) HasSynced() bool { return true } +func (t *TestController) Stop() error { return nil } diff --git a/plugin/forwardcrd/forwardcrd.go b/plugin/forwardcrd/forwardcrd.go new file mode 100644 index 000000000..533152a33 --- /dev/null +++ b/plugin/forwardcrd/forwardcrd.go @@ -0,0 +1,162 @@ +package forwardcrd + +import ( + "context" + "fmt" + "strings" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/forward" + corednsv1alpha1 "github.com/coredns/coredns/plugin/forwardcrd/apis/coredns/v1alpha1" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/client-go/dynamic" + "k8s.io/client-go/rest" + "k8s.io/client-go/tools/clientcmd" + clientcmdapi "k8s.io/client-go/tools/clientcmd/api" +) + +// ForwardCRD represents a plugin instance that can watch Forward CRDs +// within a Kubernetes clusters to dynamically configure stub-domains to proxy +// requests to an upstream resolver. +type ForwardCRD struct { + Zones []string + APIServerEndpoint string + APIClientCert string + APIClientKey string + APICertAuth string + Namespace string + ClientConfig clientcmd.ClientConfig + APIConn forwardCRDController + Next plugin.Handler + + pluginInstanceMap *PluginInstanceMap +} + +// New returns a new ForwardCRD instance. +func New() *ForwardCRD { + return &ForwardCRD{ + Namespace: "kube-system", + + pluginInstanceMap: NewPluginInstanceMap(), + } +} + +// Name implements plugin.Handler. +func (k *ForwardCRD) Name() string { return "forwardcrd" } + +// ServeDNS implements plugin.Handler. +func (k *ForwardCRD) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + question := strings.ToLower(r.Question[0].Name) + + state := request.Request{W: w, Req: r} + if !k.match(state) { + return plugin.NextOrFailure(k.Name(), k.Next, ctx, w, r) + } + + var ( + offset int + end bool + ) + + for { + p, ok := k.pluginInstanceMap.Get(question[offset:]) + if ok { + a, b := p.ServeDNS(ctx, w, r) + return a, b + } + + offset, end = dns.NextLabel(question, offset) + if end { + break + } + } + + return plugin.NextOrFailure(k.Name(), k.Next, ctx, w, r) +} + +// Ready implements the ready.Readiness interface +func (k *ForwardCRD) Ready() bool { + return k.APIConn.HasSynced() +} + +// InitKubeCache initializes a new Kubernetes cache. +func (k *ForwardCRD) InitKubeCache(ctx context.Context) error { + config, err := k.getClientConfig() + if err != nil { + return err + } + + dynamicKubeClient, err := dynamic.NewForConfig(config) + if err != nil { + return fmt.Errorf("failed to create forwardcrd controller: %q", err) + } + + scheme := runtime.NewScheme() + err = corednsv1alpha1.AddToScheme(scheme) + if err != nil { + return fmt.Errorf("failed to create forwardcrd controller: %q", err) + } + + k.APIConn = newForwardCRDController(ctx, dynamicKubeClient, scheme, k.Namespace, k.pluginInstanceMap, func(cfg forward.ForwardConfig) (lifecyclePluginHandler, error) { + return forward.NewWithConfig(cfg) + }) + + return nil +} + +func (k *ForwardCRD) getClientConfig() (*rest.Config, error) { + if k.ClientConfig != nil { + return k.ClientConfig.ClientConfig() + } + loadingRules := &clientcmd.ClientConfigLoadingRules{} + overrides := &clientcmd.ConfigOverrides{} + clusterinfo := clientcmdapi.Cluster{} + authinfo := clientcmdapi.AuthInfo{} + + // Connect to API from in cluster + if k.APIServerEndpoint == "" { + cc, err := rest.InClusterConfig() + if err != nil { + return nil, err + } + cc.ContentType = "application/vnd.kubernetes.protobuf" + return cc, err + } + + // Connect to API from out of cluster + clusterinfo.Server = k.APIServerEndpoint + + if len(k.APICertAuth) > 0 { + clusterinfo.CertificateAuthority = k.APICertAuth + } + if len(k.APIClientCert) > 0 { + authinfo.ClientCertificate = k.APIClientCert + } + if len(k.APIClientKey) > 0 { + authinfo.ClientKey = k.APIClientKey + } + + overrides.ClusterInfo = clusterinfo + overrides.AuthInfo = authinfo + clientConfig := clientcmd.NewNonInteractiveDeferredLoadingClientConfig(loadingRules, overrides) + + cc, err := clientConfig.ClientConfig() + if err != nil { + return nil, err + } + cc.ContentType = "application/vnd.kubernetes.protobuf" + return cc, err +} + +func (k *ForwardCRD) match(state request.Request) bool { + for _, zone := range k.Zones { + if plugin.Name(zone).Matches(state.Name()) || dns.Name(state.Name()) == dns.Name(zone) { + return true + } + } + + return false +} diff --git a/plugin/forwardcrd/forwardcrd_test.go b/plugin/forwardcrd/forwardcrd_test.go new file mode 100644 index 000000000..852bc14ee --- /dev/null +++ b/plugin/forwardcrd/forwardcrd_test.go @@ -0,0 +1,183 @@ +package forwardcrd + +import ( + "context" + "fmt" + "strings" + "testing" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/plugin/forward" + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +func TestDNSRequestForZone(t *testing.T) { + k, closeAll := setupForwardCRDTestcase(t, "") + defer closeAll() + + m := new(dns.Msg) + m.SetQuestion("crd.test.", dns.TypeA) + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + if _, err := k.ServeDNS(context.Background(), rec, m); err != nil { + t.Fatalf("Expected not to error: %s", err) + } + + if rec.Msg == nil || len(rec.Msg.Answer) != 1 { + t.Fatal("Expected an answer") + } + + if x := rec.Msg.Answer[0].Header().Name; x != "crd.test." { + t.Fatalf("Expected answer header name to be: %s, but got: %s", "crd.test.", x) + } + + if x := rec.Msg.Answer[0].(*dns.A).A.String(); x != "1.2.3.4" { + t.Fatalf("Expected answer ip to be: %s, but got: %s", "1.2.3.4", x) + } + + m = new(dns.Msg) + m.SetQuestion("other.test.", dns.TypeA) + rec = dnstest.NewRecorder(&test.ResponseWriter{}) + if _, err := k.ServeDNS(context.Background(), rec, m); err != nil { + t.Fatalf("Expected not to error: %s", err) + } + + if rec.Msg == nil || len(rec.Msg.Answer) != 1 { + t.Fatal("Expected an answer") + } + + if x := rec.Msg.Answer[0].Header().Name; x != "other.test." { + t.Fatalf("Expected answer header name to be: %s, but got: %s", "other.test.", x) + } + + if x := rec.Msg.Answer[0].(*dns.A).A.String(); x != "1.2.3.4" { + t.Fatalf("Expected answer ip to be: %s, but got: %s", "1.2.3.4", x) + } +} + +func TestDNSRequestForSubdomain(t *testing.T) { + k, closeAll := setupForwardCRDTestcase(t, "") + defer closeAll() + + m := new(dns.Msg) + m.SetQuestion("foo.crd.test.", dns.TypeA) + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + if _, err := k.ServeDNS(context.Background(), rec, m); err != nil { + t.Fatalf("Expected not to error: %s", err) + } + + if rec.Msg == nil || len(rec.Msg.Answer) != 1 { + t.Fatal("Expected an answer") + } + + if x := rec.Msg.Answer[0].Header().Name; x != "foo.crd.test." { + t.Fatalf("Expected answer header name to be: %s, but got: %s", "foo.crd.test.", x) + } + + if x := rec.Msg.Answer[0].(*dns.A).A.String(); x != "1.2.3.4" { + t.Fatalf("Expected answer ip to be: %s, but got: %s", "1.2.3.4", x) + } +} + +func TestDNSRequestForNonexistantZone(t *testing.T) { + k, closeAll := setupForwardCRDTestcase(t, "") + defer closeAll() + + m := new(dns.Msg) + m.SetQuestion("non-existant-zone.test.", dns.TypeA) + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + if rcode, err := k.ServeDNS(context.Background(), rec, m); err == nil || rcode != dns.RcodeServerFailure { + t.Fatalf("Expected to return rcode: %d and to error: %s", rcode, err) + } +} + +func TestDNSRequestForLimitedZones(t *testing.T) { + k, closeAll := setupForwardCRDTestcase(t, "crd.test.") + defer closeAll() + + m := new(dns.Msg) + m.SetQuestion("crd.test.", dns.TypeA) + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + if _, err := k.ServeDNS(context.Background(), rec, m); err != nil { + t.Fatalf("Expected not to error: %s", err) + } + + if rec.Msg == nil || len(rec.Msg.Answer) != 1 { + t.Fatal("Expected an answer") + } + + m = new(dns.Msg) + m.SetQuestion("sub.crd.test.", dns.TypeA) + rec = dnstest.NewRecorder(&test.ResponseWriter{}) + if _, err := k.ServeDNS(context.Background(), rec, m); err != nil { + t.Fatalf("Expected not to error: %s", err) + } + + if rec.Msg == nil || len(rec.Msg.Answer) != 1 { + t.Fatal("Expected an answer") + } + + m = new(dns.Msg) + m.SetQuestion("other.test.", dns.TypeA) + rec = dnstest.NewRecorder(&test.ResponseWriter{}) + if rcode, err := k.ServeDNS(context.Background(), rec, m); err == nil || rcode != dns.RcodeServerFailure { + t.Fatalf("Expected to return rcode: %d and to error: %s", rcode, err) + } + + m = new(dns.Msg) + m.SetQuestion("test.", dns.TypeA) + rec = dnstest.NewRecorder(&test.ResponseWriter{}) + if rcode, err := k.ServeDNS(context.Background(), rec, m); err == nil || rcode != dns.RcodeServerFailure { + t.Fatalf("Expected to return rcode: %d and to error: %s", rcode, err) + } +} + +func setupForwardCRDTestcase(t *testing.T, zone string) (*ForwardCRD, func()) { + s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { + ret := new(dns.Msg) + ret.SetReply(r) + ret.Answer = append(ret.Answer, test.A(fmt.Sprintf("%s IN A 1.2.3.4", strings.ToLower(r.Question[0].Name)))) + w.WriteMsg(ret) + }) + + c := caddy.NewTestController("dns", fmt.Sprintf("forwardcrd %s", zone)) + c.ServerBlockKeys = []string{"."} + k, err := parseForwardCRD(c) + if err != nil { + t.Errorf("Expected not to error: %s", err) + } + + k.APIConn = &TestController{} + + forwardCRDTest, err := forward.NewWithConfig(forward.ForwardConfig{ + From: "crd.test", + To: []string{s.Addr}, + }) + if err != nil { + t.Errorf("Expected not to error: %s", err) + } + + forwardCRDTest.OnStartup() + + forwardOtherTest, err := forward.NewWithConfig(forward.ForwardConfig{ + From: "other.test.", + To: []string{s.Addr}, + }) + if err != nil { + t.Errorf("Expected not to error: %s", err) + } + + forwardOtherTest.OnStartup() + + k.pluginInstanceMap.Upsert("default/crd-test", "crd.test", forwardCRDTest) + k.pluginInstanceMap.Upsert("default/other-test", "other.test", forwardOtherTest) + + closeAll := func() { + s.Close() + forwardCRDTest.OnShutdown() + forwardOtherTest.OnShutdown() + } + return k, closeAll +} diff --git a/plugin/forwardcrd/manifests/crds/coredns.io_forwards.yaml b/plugin/forwardcrd/manifests/crds/coredns.io_forwards.yaml new file mode 100644 index 000000000..af6d6ab62 --- /dev/null +++ b/plugin/forwardcrd/manifests/crds/coredns.io_forwards.yaml @@ -0,0 +1,66 @@ + +--- +apiVersion: apiextensions.k8s.io/v1 +kind: CustomResourceDefinition +metadata: + annotations: + controller-gen.kubebuilder.io/version: v0.4.1 + creationTimestamp: null + name: forwards.coredns.io +spec: + group: coredns.io + names: + kind: Forward + listKind: ForwardList + plural: forwards + singular: forward + scope: Namespaced + versions: + - additionalPrinterColumns: + - jsonPath: .spec.from + name: From + type: string + - jsonPath: .spec.to + name: To + type: string + name: v1alpha1 + schema: + openAPIV3Schema: + description: Forward represents a zone that should have its DNS requests forwarded + to an upstream DNS server within CoreDNS + properties: + apiVersion: + description: 'APIVersion defines the versioned schema of this representation + of an object. Servers should convert recognized schemas to the latest + internal value, and may reject unrecognized values. More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#resources' + type: string + kind: + description: 'Kind is a string value representing the REST resource this + object represents. Servers may infer this from the endpoint the client + submits requests to. Cannot be updated. In CamelCase. More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#types-kinds' + type: string + metadata: + type: object + spec: + description: ForwardSpec represents the spec of a Forward + properties: + from: + type: string + to: + items: + type: string + type: array + type: object + status: + description: ForwardStatus represents the status of a Forward + type: object + type: object + served: true + storage: true + subresources: {} +status: + acceptedNames: + kind: "" + plural: "" + conditions: [] + storedVersions: [] diff --git a/plugin/forwardcrd/plugin_map.go b/plugin/forwardcrd/plugin_map.go new file mode 100644 index 000000000..176e7c96b --- /dev/null +++ b/plugin/forwardcrd/plugin_map.go @@ -0,0 +1,83 @@ +package forwardcrd + +import ( + "sync" + + "github.com/coredns/coredns/plugin" +) + +// PluginInstanceMap represents a map of zones to coredns plugin instances that +// is thread-safe. It enables the forwardcrd plugin to save the state of +// which plugin instances should be delegated to for a given zone. +type PluginInstanceMap struct { + mutex *sync.RWMutex + zonesToPlugins map[string]lifecyclePluginHandler + keyToZones map[string]string +} + +// NewPluginInstanceMap returns a new instance of PluginInstanceMap. +func NewPluginInstanceMap() *PluginInstanceMap { + return &PluginInstanceMap{ + mutex: &sync.RWMutex{}, + zonesToPlugins: make(map[string]lifecyclePluginHandler), + keyToZones: make(map[string]string), + } +} + +// Upsert adds or updates the map with a zone to plugin handler mapping. If the +// same key is provided it will overwrite the old zone for that key with the +// new zone. Returns the plugin instance and true if the upsert was an update +// operation and not a create operation. +func (p *PluginInstanceMap) Upsert(key, zone string, handler lifecyclePluginHandler) (lifecyclePluginHandler, bool) { + var isUpdate bool + var oldPlugin lifecyclePluginHandler + p.mutex.Lock() + normalizedZone := plugin.Host(zone).NormalizeExact()[0] // there can only be one here, won't work with non-octet reverse + oldZone, ok := p.keyToZones[key] + if ok { + oldPlugin = p.zonesToPlugins[oldZone] + isUpdate = true + delete(p.zonesToPlugins, oldZone) + } + + p.keyToZones[key] = normalizedZone + p.zonesToPlugins[normalizedZone] = handler + p.mutex.Unlock() + return oldPlugin, isUpdate +} + +// Get gets the plugin handler provided a zone name. It will return true if the +// plugin handler exists and false if it does not exist. +func (p *PluginInstanceMap) Get(zone string) (lifecyclePluginHandler, bool) { + p.mutex.RLock() + normalizedZone := plugin.Host(zone).NormalizeExact()[0] // there can only be one here, won't work with non-octet reverse + handler, ok := p.zonesToPlugins[normalizedZone] + p.mutex.RUnlock() + return handler, ok +} + +// List lists all the plugin instances in the map. +func (p *PluginInstanceMap) List() []lifecyclePluginHandler { + p.mutex.RLock() + plugins := make([]lifecyclePluginHandler, len(p.zonesToPlugins)) + var i int + for _, v := range p.zonesToPlugins { + plugins[i] = v + i++ + } + p.mutex.RUnlock() + return plugins +} + +// Delete deletes the zone and plugin handler from the map. Returns the plugin +// instance that was deleted, useful for shutting down. Returns nil if no +// plugin was found. +func (p *PluginInstanceMap) Delete(key string) lifecyclePluginHandler { + p.mutex.Lock() + zone := p.keyToZones[key] + plugin := p.zonesToPlugins[zone] + delete(p.zonesToPlugins, zone) + delete(p.keyToZones, key) + p.mutex.Unlock() + return plugin +} diff --git a/plugin/forwardcrd/plugin_map_test.go b/plugin/forwardcrd/plugin_map_test.go new file mode 100644 index 000000000..eef0caf7d --- /dev/null +++ b/plugin/forwardcrd/plugin_map_test.go @@ -0,0 +1,89 @@ +package forwardcrd + +import ( + "sync" + "testing" + + "github.com/coredns/coredns/plugin/forward" +) + +func TestPluginMap(t *testing.T) { + pluginInstanceMap := NewPluginInstanceMap() + + zone1ForwardPlugin := forward.New() + zone2ForwardPlugin := forward.New() + + // Testing concurrency to ensure map is thread-safe + // i.e should run with `go test -race` + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + pluginInstanceMap.Upsert("default/some-dns-zone", "zone-1.test", zone1ForwardPlugin) + wg.Done() + }() + wg.Add(1) + go func() { + pluginInstanceMap.Upsert("default/another-dns-zone", "zone-2.test", zone2ForwardPlugin) + wg.Done() + }() + wg.Wait() + + if plugin, exists := pluginInstanceMap.Get("zone-1.test."); exists && plugin != zone1ForwardPlugin { + t.Fatalf("Expected plugin instance map to get plugin with address: %p but was: %p", zone1ForwardPlugin, plugin) + } + + if plugin, exists := pluginInstanceMap.Get("zone-2.test"); exists && plugin != zone2ForwardPlugin { + t.Fatalf("Expected plugin instance map to get plugin with address: %p but was: %p", zone2ForwardPlugin, plugin) + } + + if _, exists := pluginInstanceMap.Get("non-existant-zone.test"); exists { + t.Fatal("Expected plugin instance map to not return a plugin") + } + + // list + + plugins := pluginInstanceMap.List() + if len(plugins) != 2 { + t.Fatalf("Expected plugin instance map to have len %d, got: %d", 2, len(plugins)) + } + + if plugins[0] != zone1ForwardPlugin && plugins[0] != zone2ForwardPlugin { + t.Fatalf("Expected plugin instance map to list plugin[0] with address: %p or %p but was: %p", zone1ForwardPlugin, zone2ForwardPlugin, plugins[0]) + } + + if plugins[1] != zone1ForwardPlugin && plugins[1] != zone2ForwardPlugin { + t.Fatalf("Expected plugin instance map to list plugin[1] with address: %p or %p but was: %p", zone1ForwardPlugin, zone2ForwardPlugin, plugins[1]) + } + + // update record with the same key + + oldPlugin, update := pluginInstanceMap.Upsert("default/some-dns-zone", "new-zone-1.test", zone1ForwardPlugin) + + if !update { + t.Fatalf("Expected Upsert to be an update") + } + + if oldPlugin != zone1ForwardPlugin { + t.Fatalf("Expected Upsert to return the old plugin %#v, got: %#v", zone1ForwardPlugin, oldPlugin) + } + + if plugin, exists := pluginInstanceMap.Get("new-zone-1.test"); exists && plugin != zone1ForwardPlugin { + t.Fatalf("Expected plugin instance map to get plugin with address: %p but was: %p", zone1ForwardPlugin, plugin) + + } + if _, exists := pluginInstanceMap.Get("zone-1.test"); exists { + t.Fatalf("Expected plugin instance map to not get plugin with zone: %s", "zone-1.test") + } + + // delete record by key + + deletedPlugin := pluginInstanceMap.Delete("default/some-dns-zone") + + if _, exists := pluginInstanceMap.Get("new-zone-1.test"); exists { + t.Fatalf("Expected plugin instance map to not get plugin with zone: %s", "new-zone-1.test") + } + + if deletedPlugin == nil || deletedPlugin != zone1ForwardPlugin { + t.Fatalf("Expected Delete to return the deleted plugin %#v, got: %#v", zone1ForwardPlugin, deletedPlugin) + } +} diff --git a/plugin/forwardcrd/setup.go b/plugin/forwardcrd/setup.go new file mode 100644 index 000000000..6be322bc6 --- /dev/null +++ b/plugin/forwardcrd/setup.go @@ -0,0 +1,147 @@ +package forwardcrd + +import ( + "context" + "os" + "time" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/dnstap" + clog "github.com/coredns/coredns/plugin/pkg/log" + + "k8s.io/client-go/tools/clientcmd" + "k8s.io/klog/v2" +) + +const pluginName = "forwardcrd" + +var log = clog.NewWithPlugin(pluginName) + +func init() { + plugin.Register(pluginName, setup) +} + +func setup(c *caddy.Controller) error { + klog.SetOutput(os.Stdout) + + k, err := parseForwardCRD(c) + if err != nil { + return plugin.Error(pluginName, err) + } + + err = k.InitKubeCache(context.Background()) + if err != nil { + return plugin.Error(pluginName, err) + } + + dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler { + k.Next = next + return k + }) + + c.OnStartup(func() error { + go k.APIConn.Run(1) + + timeout := time.After(5 * time.Second) + ticker := time.NewTicker(100 * time.Millisecond) + for { + select { + case <-ticker.C: + if k.APIConn.HasSynced() { + return nil + } + case <-timeout: + return nil + } + } + }) + + c.OnStartup(func() error { + if taph := dnsserver.GetConfig(c).Handler("dnstap"); taph != nil { + if tapPlugin, ok := taph.(dnstap.Dnstap); ok { + k.APIConn.(*forwardCRDControl).tapPlugin = &tapPlugin + } + } + return nil + }) + + c.OnShutdown(func() error { + return k.APIConn.Stop() + }) + + return nil +} + +func parseForwardCRD(c *caddy.Controller) (*ForwardCRD, error) { + var ( + k *ForwardCRD + err error + i int + ) + + for c.Next() { + if i > 0 { + return nil, plugin.ErrOnce + } + i++ + k, err = parseStanza(c) + if err != nil { + return nil, err + } + } + + return k, nil +} + +func parseStanza(c *caddy.Controller) (*ForwardCRD, error) { + k := New() + + args := c.RemainingArgs() + k.Zones = plugin.OriginsFromArgsOrServerBlock(args, c.ServerBlockKeys) + + for c.NextBlock() { + switch c.Val() { + case "endpoint": + args := c.RemainingArgs() + if len(args) != 1 { + return nil, c.ArgErr() + } + k.APIServerEndpoint = args[0] + case "tls": + args := c.RemainingArgs() + if len(args) != 3 { + return nil, c.ArgErr() + } + k.APIClientCert, k.APIClientKey, k.APICertAuth = args[0], args[1], args[2] + case "kubeconfig": + args := c.RemainingArgs() + if len(args) != 1 && len(args) != 2 { + return nil, c.ArgErr() + } + overrides := &clientcmd.ConfigOverrides{} + if len(args) == 2 { + overrides.CurrentContext = args[1] + } + config := clientcmd.NewNonInteractiveDeferredLoadingClientConfig( + &clientcmd.ClientConfigLoadingRules{ExplicitPath: args[0]}, + overrides, + ) + k.ClientConfig = config + case "namespace": + args := c.RemainingArgs() + if len(args) == 0 { + k.Namespace = "" + } else if len(args) == 1 { + k.Namespace = args[0] + } else { + return nil, c.ArgErr() + } + default: + return nil, c.Errf("unknown property '%s'", c.Val()) + } + } + + return k, nil +} diff --git a/plugin/forwardcrd/setup_test.go b/plugin/forwardcrd/setup_test.go new file mode 100644 index 000000000..e5c7cb592 --- /dev/null +++ b/plugin/forwardcrd/setup_test.go @@ -0,0 +1,194 @@ +package forwardcrd + +import ( + "strings" + "testing" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/plugin" +) + +func TestForwardCRDParse(t *testing.T) { + c := caddy.NewTestController("dns", `forwardcrd`) + k, err := parseForwardCRD(c) + if err != nil { + t.Fatalf("Expected no errors, but got: %v", err) + } + if k.Namespace != "kube-system" { + t.Errorf("Expected Namespace to be: %s\n but was: %s\n", "kube-system", k.Namespace) + } + + c = caddy.NewTestController("dns", `forwardcrd { + endpoint http://localhost:9090 + }`) + k, err = parseForwardCRD(c) + if err != nil { + t.Fatalf("Expected no errors, but got: %v", err) + } + if k.APIServerEndpoint != "http://localhost:9090" { + t.Errorf("Expected APIServerEndpoint to be: %s\n but was: %s\n", "http://localhost:9090", k.APIServerEndpoint) + } + + c = caddy.NewTestController("dns", `forwardcrd { + tls cert.crt key.key cacert.crt + }`) + k, err = parseForwardCRD(c) + if err != nil { + t.Fatalf("Expected no errors, but got: %v", err) + } + if k.APIClientCert != "cert.crt" { + t.Errorf("Expected APIClientCert to be: %s\n but was: %s\n", "cert.crt", k.APIClientCert) + } + if k.APIClientKey != "key.key" { + t.Errorf("Expected APIClientCert to be: %s\n but was: %s\n", "key.key", k.APIClientKey) + } + if k.APICertAuth != "cacert.crt" { + t.Errorf("Expected APICertAuth to be: %s\n but was: %s\n", "cacert.crt", k.APICertAuth) + } + + c = caddy.NewTestController("dns", `forwardcrd { + kubeconfig foo.kubeconfig + }`) + _, err = parseForwardCRD(c) + if err != nil { + t.Fatalf("Expected no errors, but got: %v", err) + } + + c = caddy.NewTestController("dns", `forwardcrd { + kubeconfig foo.kubeconfig context + }`) + _, err = parseForwardCRD(c) + if err != nil { + t.Fatalf("Expected no errors, but got: %v", err) + } + + c = caddy.NewTestController("dns", `forwardcrd example.org`) + k, err = parseForwardCRD(c) + if err != nil { + t.Fatalf("Expected no errors, but got: %v", err) + } + if len(k.Zones) != 1 || k.Zones[0] != "example.org." { + t.Fatalf("Expected Zones to consist of \"example.org.\" but was %v", k.Zones) + } + + c = caddy.NewTestController("dns", `forwardcrd`) + c.ServerBlockKeys = []string{"example.org"} + k, err = parseForwardCRD(c) + if err != nil { + t.Fatalf("Expected no errors, but got: %v", err) + } + if len(k.Zones) != 1 || k.Zones[0] != "example.org." { + t.Fatalf("Expected Zones to consist of \"example.org.\" but was %v", k.Zones) + } + + c = caddy.NewTestController("dns", `forwardcrd { + namespace + }`) + k, err = parseForwardCRD(c) + if err != nil { + t.Fatalf("Expected no errors, but got: %v", err) + } + if k.Namespace != "" { + t.Errorf("Expected Namespace to be: %q\n but was: %q\n", "", k.Namespace) + } + + c = caddy.NewTestController("dns", `forwardcrd { + namespace dns-system + }`) + k, err = parseForwardCRD(c) + if err != nil { + t.Fatalf("Expected no errors, but got: %v", err) + } + if k.Namespace != "dns-system" { + t.Errorf("Expected Namespace to be: %s\n but was: %s\n", "dns-system", k.Namespace) + } + + // negative + + c = caddy.NewTestController("dns", `forwardcrd { + endpoint http://localhost:9090 http://foo.bar:1024 + }`) + _, err = parseForwardCRD(c) + if err == nil { + t.Fatalf("Expected errors, but got nil") + } + if !strings.Contains(err.Error(), "Wrong argument count") { + t.Fatalf("Expected error containing \"Wrong argument count\", but got: %v", err.Error()) + } + + c = caddy.NewTestController("dns", `forwardcrd { + endpoint + }`) + _, err = parseForwardCRD(c) + if err == nil { + t.Fatalf("Expected errors, but got nil") + } + if !strings.Contains(err.Error(), "Wrong argument count") { + t.Fatalf("Expected error containing \"Wrong argument count\", but got: %v", err.Error()) + } + + c = caddy.NewTestController("dns", `forwardcrd { + tls foo bar + }`) + _, err = parseForwardCRD(c) + if err == nil { + t.Fatalf("Expected errors, but got nil") + } + if !strings.Contains(err.Error(), "Wrong argument count") { + t.Fatalf("Expected error containing \"Wrong argument count\", but got: %v", err.Error()) + } + + c = caddy.NewTestController("dns", `forwardcrd { + kubeconfig + }`) + _, err = parseForwardCRD(c) + if err == nil { + t.Fatalf("Expected errors, but got nil") + } + if !strings.Contains(err.Error(), "Wrong argument count") { + t.Fatalf("Expected error containing \"Wrong argument count\", but got: %v", err.Error()) + } + + c = caddy.NewTestController("dns", `forwardcrd { + kubeconfig too many args + }`) + _, err = parseForwardCRD(c) + if err == nil { + t.Fatalf("Expected errors, but got nil") + } + if !strings.Contains(err.Error(), "Wrong argument count") { + t.Fatalf("Expected error containing \"Wrong argument count\", but got: %v", err.Error()) + } + + c = caddy.NewTestController("dns", `forwardcrd { + namespace too many args + }`) + _, err = parseForwardCRD(c) + if err == nil { + t.Fatalf("Expected errors, but got nil") + } + if !strings.Contains(err.Error(), "Wrong argument count") { + t.Fatalf("Expected error containing \"Wrong argument count\", but got: %v", err.Error()) + } + + c = caddy.NewTestController("dns", `forwardcrd { + invalid + }`) + _, err = parseForwardCRD(c) + if err == nil { + t.Fatalf("Expected errors, but got nil") + } + if !strings.Contains(err.Error(), "unknown property") { + t.Fatalf("Expected error containing \"unknown property\", but got: %v", err.Error()) + } + + c = caddy.NewTestController("dns", `forwardcrd +forwardcrd`) + _, err = parseForwardCRD(c) + if err == nil { + t.Fatalf("Expected errors, but got nil") + } + if !strings.Contains(err.Error(), plugin.ErrOnce.Error()) { + t.Fatalf("Expected error containing \"%s\", but got: %v", plugin.ErrOnce.Error(), err.Error()) + } +}