From 00128bda4e2fa56ec4bf66e321bf3e00263d1188 Mon Sep 17 00:00:00 2001 From: Qasim Sarfraz Date: Thu, 15 Jul 2021 09:32:39 +0200 Subject: [PATCH] plugin/header: Introduce header plugin (#4752) * Add header plugin Signed-off-by: MQasimSarfraz * fix import format * improve README.md * Add codeowners for header plugin --- CODEOWNERS | 1 + plugin.cfg | 1 + plugin/header/README.md | 52 ++++++++++++++++++++ plugin/header/handler.go | 24 ++++++++++ plugin/header/header.go | 92 ++++++++++++++++++++++++++++++++++++ plugin/header/header_test.go | 84 ++++++++++++++++++++++++++++++++ plugin/header/setup.go | 50 ++++++++++++++++++++ plugin/header/setup_test.go | 53 +++++++++++++++++++++ 8 files changed, 357 insertions(+) create mode 100644 plugin/header/README.md create mode 100644 plugin/header/handler.go create mode 100644 plugin/header/header.go create mode 100644 plugin/header/header_test.go create mode 100644 plugin/header/setup.go create mode 100644 plugin/header/setup_test.go diff --git a/CODEOWNERS b/CODEOWNERS index e2562db12..0ef38a5a2 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -33,6 +33,7 @@ go.mod @miekg @chrisohaver @johnbelamaric @yongtang @stp-ip /plugin/geoip/ @miekg @snebel29 /plugin/grpc/ @inigohu @miekg @zouyee /plugin/health/ @fastest963 @miekg @zouyee +/plugin/header/ @miekg @mqasimsarfraz /plugin/hosts/ @johnbelamaric @pmoroney /plugin/k8s_external/ @miekg /plugin/kubernetes/ @bradbeam @chrisohaver @johnbelamaric @miekg @rajansandeep @yongtang @zouyee diff --git a/plugin.cfg b/plugin.cfg index a80ea97ef..628e71412 100644 --- a/plugin.cfg +++ b/plugin.cfg @@ -45,6 +45,7 @@ chaos:chaos loadbalance:loadbalance cache:cache rewrite:rewrite +header:header dnssec:dnssec autopath:autopath minimal:minimal diff --git a/plugin/header/README.md b/plugin/header/README.md new file mode 100644 index 000000000..862d23b02 --- /dev/null +++ b/plugin/header/README.md @@ -0,0 +1,52 @@ +# header + +## Name + +*header* - modifies the header for all the responses. + +## Description + +It ensures that the flags are in the desired state for all the responses. The modifications are made transparently for +the client. + +## Syntax + +~~~ +header { + ACTION FLAGS... + ACTION FLAGS... +} +~~~ + +* **ACTION** defines the state for dns flags. Actions are evaluated in the order they are defined so last one has the + most precedence. Allowed values are: + * `set` + * `clear` +* **FLAGS** are the dns flags that will be modified. Current supported flags include: + * `aa` - Authoritative + * `ra` - RecursionAvailable + * `rd` - RecursionDesired + +## Examples + +Make sure recursive available `ra` flag is set in all the responses: + +~~~ corefile +. { + header { + set ra + } +} +~~~ + +Make sure recursive available `ra` and authoritative `aa` flags are set and recursive desired is cleared in all the +responses: + +~~~ corefile +. { + header { + set ra aa + clear rd + } +} +~~~ diff --git a/plugin/header/handler.go b/plugin/header/handler.go new file mode 100644 index 000000000..a002c4fb1 --- /dev/null +++ b/plugin/header/handler.go @@ -0,0 +1,24 @@ +package header + +import ( + "context" + + "github.com/coredns/coredns/plugin" + + "github.com/miekg/dns" +) + +// Header modifies dns.MsgHdr in the responses +type Header struct { + Rules []Rule + Next plugin.Handler +} + +// ServeDNS implements the plugin.Handler interface. +func (h Header) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + wr := ResponseHeaderWriter{ResponseWriter: w, Rules: h.Rules} + return plugin.NextOrFailure(h.Name(), h.Next, ctx, &wr, r) +} + +// Name implements the plugin.Handler interface. +func (h Header) Name() string { return "header" } diff --git a/plugin/header/header.go b/plugin/header/header.go new file mode 100644 index 000000000..660b07146 --- /dev/null +++ b/plugin/header/header.go @@ -0,0 +1,92 @@ +package header + +import ( + "fmt" + "strings" + + clog "github.com/coredns/coredns/plugin/pkg/log" + + "github.com/miekg/dns" +) + +// Supported flags +const ( + authoritative = "aa" + recursionAvailable = "ra" + recursionDesired = "rd" +) + +var log = clog.NewWithPlugin("header") + +// ResponseHeaderWriter is a response writer that allows modifying dns.MsgHdr +type ResponseHeaderWriter struct { + dns.ResponseWriter + Rules []Rule +} + +// WriteMsg implements the dns.ResponseWriter interface. +func (r *ResponseHeaderWriter) WriteMsg(res *dns.Msg) error { + // handle all supported flags + for _, rule := range r.Rules { + switch rule.Flag { + case authoritative: + res.Authoritative = rule.State + case recursionAvailable: + res.RecursionAvailable = rule.State + case recursionDesired: + res.RecursionDesired = rule.State + } + } + + return r.ResponseWriter.WriteMsg(res) +} + +// Write implements the dns.ResponseWriter interface. +func (r *ResponseHeaderWriter) Write(buf []byte) (int, error) { + log.Warning("ResponseHeaderWriter called with Write: not ensuring headers") + n, err := r.ResponseWriter.Write(buf) + return n, err +} + +// Rule is used to set/clear Flag in dns.MsgHdr +type Rule struct { + Flag string + State bool +} + +func newRules(key string, args []string) ([]Rule, error) { + if key == "" { + return nil, fmt.Errorf("no flag action provided") + } + + if len(args) < 1 { + return nil, fmt.Errorf("invalid length for flags, at least one should be provided") + } + + var state bool + action := strings.ToLower(key) + switch action { + case "set": + state = true + case "clear": + state = false + default: + return nil, fmt.Errorf("unknown flag action=%s, should be set or clear", action) + } + + var rules []Rule + for _, arg := range args { + flag := strings.ToLower(arg) + switch flag { + case authoritative: + case recursionAvailable: + case recursionDesired: + default: + return nil, fmt.Errorf("unknown/unsupported flag=%s", flag) + } + rule := Rule{Flag: flag, State: state} + rules = append(rules, rule) + } + + return rules, nil +} diff --git a/plugin/header/header_test.go b/plugin/header/header_test.go new file mode 100644 index 000000000..396e26d07 --- /dev/null +++ b/plugin/header/header_test.go @@ -0,0 +1,84 @@ +package header + +import ( + "context" + "testing" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +func TestHeader(t *testing.T) { + wr := dnstest.NewRecorder(&test.ResponseWriter{}) + next := plugin.HandlerFunc(func(ctx context.Context, writer dns.ResponseWriter, msg *dns.Msg) (int, error) { + writer.WriteMsg(msg) + return dns.RcodeSuccess, nil + }) + + tests := []struct { + handler plugin.Handler + got func(msg *dns.Msg) bool + expected bool + }{ + { + handler: Header{ + Rules: []Rule{{Flag: recursionAvailable, State: true}}, + Next: next, + }, + got: func(msg *dns.Msg) bool { + return msg.RecursionAvailable + }, + expected: true, + }, + { + handler: Header{ + Rules: []Rule{{Flag: recursionAvailable, State: true}}, + Next: next, + }, + got: func(msg *dns.Msg) bool { + return msg.RecursionAvailable + }, + expected: true, + }, + { + handler: Header{ + Rules: []Rule{{Flag: recursionDesired, State: true}}, + Next: next, + }, + got: func(msg *dns.Msg) bool { + return msg.RecursionDesired + }, + expected: true, + }, + { + handler: Header{ + Rules: []Rule{{Flag: authoritative, State: true}}, + Next: next, + }, + got: func(msg *dns.Msg) bool { + return msg.Authoritative + }, + expected: true, + }, + } + + for i, test := range tests { + m := new(dns.Msg) + + _, err := test.handler.ServeDNS(context.TODO(), wr, m) + if err != nil { + t.Errorf("Test %d: Expected no error, but got %s", i, err) + continue + } + + if test.got(m) != test.expected { + t.Errorf("Test %d: Expected flag state=%t, but got %t", i, test.expected, test.got(m)) + continue + } + + } + +} diff --git a/plugin/header/setup.go b/plugin/header/setup.go new file mode 100644 index 000000000..b0e67206a --- /dev/null +++ b/plugin/header/setup.go @@ -0,0 +1,50 @@ +package header + +import ( + "fmt" + + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" + "github.com/coredns/caddy" +) + +func init() { plugin.Register("header", setup) } + +func setup(c *caddy.Controller) error { + rules, err := parse(c) + if err != nil { + return plugin.Error("header", err) + } + + dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler { + return Header{ + Rules: rules, + Next: next, + } + }) + + return nil +} + +func parse(c *caddy.Controller) ([]Rule, error) { + for c.Next() { + var all []Rule + for c.NextBlock() { + v := c.Val() + args := c.RemainingArgs() + // set up rules + rules, err := newRules(v, args) + if err != nil { + return nil, fmt.Errorf("seting up rule: %w", err) + } + all = append(all, rules...) + } + + // return combined rules + if len(all) > 0 { + return all, nil + } + } + return nil, c.ArgErr() + +} diff --git a/plugin/header/setup_test.go b/plugin/header/setup_test.go new file mode 100644 index 000000000..a8570cc2e --- /dev/null +++ b/plugin/header/setup_test.go @@ -0,0 +1,53 @@ +package header + +import ( + "strings" + "testing" + + "github.com/coredns/caddy" +) + +func TestSetupHeader(t *testing.T) { + tests := []struct { + input string + shouldErr bool + expectedErrContent string + }{ + {`header {}`, true, "Wrong argument count or unexpected line ending after"}, + {`header { + set +}`, true, "invalid length for flags, at least one should be provided"}, + {`header { + foo +}`, true, "invalid length for flags, at least one should be provided"}, + {`header { + foo bar +}`, true, "unknown flag action=foo, should be set or clear"}, + {`header { + set ra +}`, false, ""}, + {`header { + set ra aa + clear rd +}`, false, ""}, + } + + for i, test := range tests { + c := caddy.NewTestController("dns", test.input) + err := setup(c) + + if test.shouldErr && err == nil { + t.Errorf("Test %d: Expected error but found none for input %s", i, test.input) + } + + if err != nil { + if !test.shouldErr { + t.Errorf("Test %d: Expected no error but found one for input %s. Error was: %v", i, test.input, err) + } + + if !strings.Contains(err.Error(), test.expectedErrContent) { + t.Errorf("Test %d: Expected error to contain: %v, found error: %v, input: %s", i, test.expectedErrContent, err, test.input) + } + } + } +}