diff --git a/core/dnsserver/zdirectives.go b/core/dnsserver/zdirectives.go index 046600c54..aa8a13ae2 100644 --- a/core/dnsserver/zdirectives.go +++ b/core/dnsserver/zdirectives.go @@ -10,6 +10,7 @@ package dnsserver // (after) them during a request, but they must not // care what plugin above them are doing. var Directives = []string{ + "metadata", "tls", "reload", "nsid", diff --git a/core/plugin/zplugin.go b/core/plugin/zplugin.go index d1c0aaa73..630ece346 100644 --- a/core/plugin/zplugin.go +++ b/core/plugin/zplugin.go @@ -24,6 +24,7 @@ import ( _ "github.com/coredns/coredns/plugin/kubernetes" _ "github.com/coredns/coredns/plugin/loadbalance" _ "github.com/coredns/coredns/plugin/log" + _ "github.com/coredns/coredns/plugin/metadata" _ "github.com/coredns/coredns/plugin/metrics" _ "github.com/coredns/coredns/plugin/nsid" _ "github.com/coredns/coredns/plugin/pprof" diff --git a/plugin.cfg b/plugin.cfg index 42070373e..646aee92d 100644 --- a/plugin.cfg +++ b/plugin.cfg @@ -19,6 +19,7 @@ # Local plugin example: # log:log +metadata:metadata tls:tls reload:reload nsid:nsid diff --git a/plugin/metadata/README.md b/plugin/metadata/README.md new file mode 100644 index 000000000..32f58baa8 --- /dev/null +++ b/plugin/metadata/README.md @@ -0,0 +1,47 @@ +# metadata + +## Name + +*metadata* - enable a metadata collector. + +## Description + +By enabling *metadata* any plugin that implements [metadata.Provider interface](https://godoc.org/github.com/coredns/coredns/plugin/metadata#Provider) will be called for each DNS query, at being of the process for that query, in order to add it's own Metadata to context. The metadata collected will be available for all plugins handler, via the Context parameter provided in the ServeDNS function. +Metadata plugin is automatically adding the so-called default medatada (extracted from the query) to the context. Those default metadata are: {qname}, {qtype}, {client_ip}, {client_port}, {protocol}, {server_ip}, {server_port} + + +## Syntax + +~~~ +metadata [ZONES... ] +~~~ + +## Plugins + +metadata.Provider interface needs to be implemented by each plugin willing to provide metadata information for other plugins. It will be called by metadata and gather the information from all plugins in context. +Note: this method should work quickly, because it is called for every request +from the metadata plugin. +If **ZONES** is specified then metadata add is limited by zones. Metadata is added to every context going through metadata.Provider if **ZONES** are not specified. + + +## Examples + +Enable metadata for all requests. Rewrite uses one of the provided by default metadata variables. + +~~~ corefile +. { + metadata + rewrite edns0 local set 0xffee {client_ip} + forward . 8.8.8.8:53 +} +~~~ + +Add metadata for all requests within `example.org.`. Rewrite uses one of provided by default metadata variables. Any other requests won't have metadata. + +~~~ corefile +. { + metadata example.org + rewrite edns0 local set 0xffee {client_ip} + forward . 8.8.8.8:53 +} +~~~ diff --git a/plugin/metadata/metadata.go b/plugin/metadata/metadata.go new file mode 100644 index 000000000..1e840d3fd --- /dev/null +++ b/plugin/metadata/metadata.go @@ -0,0 +1,55 @@ +package metadata + +import ( + "context" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/pkg/variables" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +// Metadata implements collecting metadata information from all plugins that +// implement the Provider interface. +type Metadata struct { + Zones []string + Providers []Provider + Next plugin.Handler +} + +// Name implements the Handler interface. +func (m *Metadata) Name() string { return "metadata" } + +// ServeDNS implements the plugin.Handler interface. +func (m *Metadata) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + + md, ctx := newMD(ctx) + + state := request.Request{W: w, Req: r} + if plugin.Zones(m.Zones).Matches(state.Name()) != "" { + // Go through all Providers and collect metadata + for _, provider := range m.Providers { + for _, varName := range provider.MetadataVarNames() { + if val, ok := provider.Metadata(ctx, w, r, varName); ok { + md.setValue(varName, val) + } + } + } + } + + rcode, err := plugin.NextOrFailure(m.Name(), m.Next, ctx, w, r) + + return rcode, err +} + +// MetadataVarNames implements the plugin.Provider interface. +func (m *Metadata) MetadataVarNames() []string { return variables.All } + +// Metadata implements the plugin.Provider interface. +func (m *Metadata) Metadata(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, varName string) (interface{}, bool) { + if val, err := variables.GetValue(varName, w, r); err == nil { + return val, true + } + return nil, false +} diff --git a/plugin/metadata/metadata_test.go b/plugin/metadata/metadata_test.go new file mode 100644 index 000000000..413ba874e --- /dev/null +++ b/plugin/metadata/metadata_test.go @@ -0,0 +1,79 @@ +package metadata + +import ( + "context" + "testing" + + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +// testProvider implements fake Providers. Plugins which inmplement Provider interface +type testProvider map[string]interface{} + +func (m testProvider) MetadataVarNames() []string { + keys := []string{} + for k := range m { + keys = append(keys, k) + } + return keys +} + +func (m testProvider) Metadata(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, key string) (val interface{}, ok bool) { + value, ok := m[key] + return value, ok +} + +// testHandler implements plugin.Handler +type testHandler struct{ ctx context.Context } + +func (m *testHandler) Name() string { return "testHandler" } + +func (m *testHandler) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + m.ctx = ctx + return 0, nil +} + +func TestMetadataServDns(t *testing.T) { + expectedMetadata := []testProvider{ + testProvider{"testkey1": "testvalue1"}, + testProvider{"testkey2": 2, "testkey3": "testvalue3"}, + } + // Create fake Providers based on expectedMetadata + providers := []Provider{} + for _, e := range expectedMetadata { + providers = append(providers, e) + } + // Fake handler which stores the resulting context + next := &testHandler{} + + metadata := Metadata{ + Zones: []string{"."}, + Providers: providers, + Next: next, + } + metadata.ServeDNS(context.TODO(), &test.ResponseWriter{}, new(dns.Msg)) + + // Verify that next plugin can find metadata in context from all Providers + for _, expected := range expectedMetadata { + md, ok := FromContext(next.ctx) + if !ok { + t.Fatalf("Metadata is expected but not present inside the context") + } + for expKey, expVal := range expected { + metadataVal, valOk := md.Value(expKey) + if !valOk { + t.Fatalf("Value by key %v can't be retrieved", expKey) + } + if metadataVal != expVal { + t.Errorf("Expected value %v, but got %v", expVal, metadataVal) + } + } + wrongKey := "wrong_key" + metadataVal, ok := md.Value(wrongKey) + if ok { + t.Fatalf("Value by key %v is not expected to be recieved, but got: %v", wrongKey, metadataVal) + } + } +} diff --git a/plugin/metadata/metadataer.go b/plugin/metadata/metadataer.go new file mode 100644 index 000000000..bff12e92d --- /dev/null +++ b/plugin/metadata/metadataer.go @@ -0,0 +1,53 @@ +package metadata + +import ( + "context" + + "github.com/miekg/dns" +) + +// Provider interface needs to be implemented by each plugin willing to provide +// metadata information for other plugins. +// Note: this method should work quickly, because it is called for every request +// from the metadata plugin. +type Provider interface { + // List of variables which are provided by current Provider. Must remain constant. + MetadataVarNames() []string + // Metadata is expected to return a value with metadata information by the key + // from 4th argument. Value can be later retrieved from context by any other plugin. + // If value is not available by some reason returned boolean value should be false. + Metadata(context.Context, dns.ResponseWriter, *dns.Msg, string) (interface{}, bool) +} + +// MD is metadata information storage +type MD map[string]interface{} + +// metadataKey defines the type of key that is used to save metadata into the context +type metadataKey struct{} + +// newMD initializes MD and attaches it to context +func newMD(ctx context.Context) (MD, context.Context) { + m := MD{} + return m, context.WithValue(ctx, metadataKey{}, m) +} + +// FromContext retrieves MD struct from context. +func FromContext(ctx context.Context) (md MD, ok bool) { + if metadata := ctx.Value(metadataKey{}); metadata != nil { + if md, ok := metadata.(MD); ok { + return md, true + } + } + return MD{}, false +} + +// Value returns metadata value by key. +func (m MD) Value(key string) (value interface{}, ok bool) { + value, ok = m[key] + return value, ok +} + +// setValue adds metadata value. +func (m MD) setValue(key string, val interface{}) { + m[key] = val +} diff --git a/plugin/metadata/metadataer_test.go b/plugin/metadata/metadataer_test.go new file mode 100644 index 000000000..53096feb8 --- /dev/null +++ b/plugin/metadata/metadataer_test.go @@ -0,0 +1,47 @@ +package metadata + +import ( + "context" + "reflect" + "testing" +) + +func TestMD(t *testing.T) { + tests := []struct { + addValues map[string]interface{} + expectedValues map[string]interface{} + }{ + { + // Add initial metadata key/vals + map[string]interface{}{"key1": "val1", "key2": 2}, + map[string]interface{}{"key1": "val1", "key2": 2}, + }, + { + // Add additional key/vals. + map[string]interface{}{"key3": 3, "key4": 4.5}, + map[string]interface{}{"key1": "val1", "key2": 2, "key3": 3, "key4": 4.5}, + }, + } + + // Using one same md and ctx for all test cases + ctx := context.TODO() + md, ctx := newMD(ctx) + + for i, tc := range tests { + for k, v := range tc.addValues { + md.setValue(k, v) + } + if !reflect.DeepEqual(tc.expectedValues, map[string]interface{}(md)) { + t.Errorf("Test %d: Expected %v but got %v", i, tc.expectedValues, md) + } + + // Make sure that MD is recieved from context successfullly + mdFromContext, ok := FromContext(ctx) + if !ok { + t.Errorf("Test %d: MD is not recieved from the context", i) + } + if !reflect.DeepEqual(md, mdFromContext) { + t.Errorf("Test %d: MD recieved from context differs from initial. Initial: %v, from context: %v", i, md, mdFromContext) + } + } +} diff --git a/plugin/metadata/setup.go b/plugin/metadata/setup.go new file mode 100644 index 000000000..33a153a2c --- /dev/null +++ b/plugin/metadata/setup.go @@ -0,0 +1,71 @@ +package metadata + +import ( + "fmt" + + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" + + "github.com/mholt/caddy" +) + +func init() { + caddy.RegisterPlugin("metadata", caddy.Plugin{ + ServerType: "dns", + Action: setup, + }) +} + +func setup(c *caddy.Controller) error { + m, err := metadataParse(c) + if err != nil { + return err + } + dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler { + m.Next = next + return m + }) + + c.OnStartup(func() error { + plugins := dnsserver.GetConfig(c).Handlers() + // Collect all plugins which implement Provider interface + metadataVariables := map[string]bool{} + for _, p := range plugins { + if met, ok := p.(Provider); ok { + for _, varName := range met.MetadataVarNames() { + if _, ok := metadataVariables[varName]; ok { + return fmt.Errorf("Metadata variable '%v' has duplicates", varName) + } + metadataVariables[varName] = true + } + m.Providers = append(m.Providers, met) + } + } + return nil + }) + + return nil +} + +func metadataParse(c *caddy.Controller) (*Metadata, error) { + m := &Metadata{} + c.Next() + zones := c.RemainingArgs() + + if len(zones) != 0 { + m.Zones = zones + for i := 0; i < len(m.Zones); i++ { + m.Zones[i] = plugin.Host(m.Zones[i]).Normalize() + } + } else { + m.Zones = make([]string, len(c.ServerBlockKeys)) + for i := 0; i < len(c.ServerBlockKeys); i++ { + m.Zones[i] = plugin.Host(c.ServerBlockKeys[i]).Normalize() + } + } + + if c.NextBlock() || c.Next() { + return nil, plugin.Error("metadata", c.ArgErr()) + } + return m, nil +} diff --git a/plugin/metadata/setup_test.go b/plugin/metadata/setup_test.go new file mode 100644 index 000000000..362a1bbf3 --- /dev/null +++ b/plugin/metadata/setup_test.go @@ -0,0 +1,70 @@ +package metadata + +import ( + "reflect" + "testing" + + "github.com/mholt/caddy" +) + +func TestSetup(t *testing.T) { + tests := []struct { + input string + zones []string + shouldErr bool + }{ + {"metadata", []string{}, false}, + {"metadata example.com.", []string{"example.com."}, false}, + {"metadata example.com. net.", []string{"example.com.", "net."}, false}, + + {"metadata example.com. { some_param }", []string{}, true}, + {"metadata\nmetadata", []string{}, true}, + } + + for i, test := range tests { + c := caddy.NewTestController("dns", test.input) + err := setup(c) + + if test.shouldErr && err == nil { + t.Errorf("Test %d: Setup call expected error but found none for input %s", i, test.input) + } + + if !test.shouldErr && err != nil { + t.Errorf("Test %d: Setup call expected no error but found one for input %s. Error was: %v", i, test.input, err) + } + } +} + +func TestSetupHealth(t *testing.T) { + tests := []struct { + input string + zones []string + shouldErr bool + }{ + {"metadata", []string{}, false}, + {"metadata example.com.", []string{"example.com."}, false}, + {"metadata example.com. net.", []string{"example.com.", "net."}, false}, + + {"metadata example.com. { some_param }", []string{}, true}, + {"metadata\nmetadata", []string{}, true}, + } + + for i, test := range tests { + c := caddy.NewTestController("dns", test.input) + m, err := metadataParse(c) + + if test.shouldErr && err == nil { + t.Errorf("Test %d: Expected error but found none for input %s", i, test.input) + } + + if !test.shouldErr && err != nil { + t.Errorf("Test %d: Expected no error but found one for input %s. Error was: %v", i, test.input, err) + } + + if !test.shouldErr && err == nil { + if !reflect.DeepEqual(test.zones, m.Zones) { + t.Errorf("Test %d: Expected zones %s. Zones were: %v", i, test.zones, m.Zones) + } + } + } +} diff --git a/plugin/pkg/variables/variables.go b/plugin/pkg/variables/variables.go new file mode 100644 index 000000000..da1dccbee --- /dev/null +++ b/plugin/pkg/variables/variables.go @@ -0,0 +1,109 @@ +package variables + +import ( + "encoding/binary" + "fmt" + "net" + "strconv" + + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +const ( + queryName = "qname" + queryType = "qtype" + clientIP = "client_ip" + clientPort = "client_port" + protocol = "protocol" + serverIP = "server_ip" + serverPort = "server_port" +) + +// All is a list of available variables provided by GetMetadataValue +var All = []string{queryName, queryType, clientIP, clientPort, protocol, serverIP, serverPort} + +// GetValue calculates and returns the data specified by the variable name. +// Supported varNames are listed in allProvidedVars. +func GetValue(varName string, w dns.ResponseWriter, r *dns.Msg) ([]byte, error) { + req := request.Request{W: w, Req: r} + switch varName { + case queryName: + //Query name is written as ascii string + return []byte(req.QName()), nil + + case queryType: + return uint16ToWire(req.QType()), nil + + case clientIP: + return ipToWire(req.Family(), req.IP()) + + case clientPort: + return portToWire(req.Port()) + + case protocol: + // Proto is written as ascii string + return []byte(req.Proto()), nil + + case serverIP: + ip, _, err := net.SplitHostPort(w.LocalAddr().String()) + if err != nil { + ip = w.RemoteAddr().String() + } + return ipToWire(family(w.RemoteAddr()), ip) + + case serverPort: + _, port, err := net.SplitHostPort(w.LocalAddr().String()) + if err != nil { + port = "0" + } + return portToWire(port) + } + + return nil, fmt.Errorf("unable to extract data for variable %s", varName) +} + +// uint16ToWire writes unit16 to wire/binary format +func uint16ToWire(data uint16) []byte { + buf := make([]byte, 2) + binary.BigEndian.PutUint16(buf, uint16(data)) + return buf +} + +// ipToWire writes IP address to wire/binary format, 4 or 16 bytes depends on IPV4 or IPV6. +func ipToWire(family int, ipAddr string) ([]byte, error) { + + switch family { + case 1: + return net.ParseIP(ipAddr).To4(), nil + case 2: + return net.ParseIP(ipAddr).To16(), nil + } + return nil, fmt.Errorf("invalid IP address family (i.e. version) %d", family) +} + +// portToWire writes port to wire/binary format, 2 bytes +func portToWire(portStr string) ([]byte, error) { + + port, err := strconv.ParseUint(portStr, 10, 16) + if err != nil { + return nil, err + } + return uint16ToWire(uint16(port)), nil +} + +// Family returns the family of the transport, 1 for IPv4 and 2 for IPv6. +func family(ip net.Addr) int { + var a net.IP + if i, ok := ip.(*net.UDPAddr); ok { + a = i.IP + } + if i, ok := ip.(*net.TCPAddr); ok { + a = i.IP + } + if a.To4() != nil { + return 1 + } + return 2 +} diff --git a/plugin/pkg/variables/variables_test.go b/plugin/pkg/variables/variables_test.go new file mode 100644 index 000000000..939add323 --- /dev/null +++ b/plugin/pkg/variables/variables_test.go @@ -0,0 +1,80 @@ +package variables + +import ( + "bytes" + "testing" + + "github.com/coredns/coredns/plugin/test" + "github.com/miekg/dns" +) + +func TestGetValue(t *testing.T) { + // test.ResponseWriter has the following values: + // The remote will always be 10.240.0.1 and port 40212. + // The local address is always 127.0.0.1 and port 53. + tests := []struct { + varName string + expectedValue []byte + shouldErr bool + }{ + { + queryName, + []byte("example.com."), + false, + }, + { + queryType, + []byte{0x00, 0x01}, + false, + }, + { + clientIP, + []byte{10, 240, 0, 1}, + false, + }, + { + clientPort, + []byte{0x9D, 0x14}, + false, + }, + { + protocol, + []byte("udp"), + false, + }, + { + serverIP, + []byte{127, 0, 0, 1}, + false, + }, + { + serverPort, + []byte{0, 53}, + false, + }, + { + "wrong_var", + []byte{}, + true, + }, + } + + for i, tc := range tests { + m := new(dns.Msg) + m.SetQuestion("example.com.", dns.TypeA) + m.Question[0].Qclass = dns.ClassINET + + value, err := GetValue(tc.varName, &test.ResponseWriter{}, m) + + if tc.shouldErr && err == nil { + t.Errorf("Test %d: Expected error, but didn't recieve", i) + } + if !tc.shouldErr && err != nil { + t.Errorf("Test %d: Expected no error, but got error: %v", i, err.Error()) + } + + if !bytes.Equal(tc.expectedValue, value) { + t.Errorf("Test %d: Expected %v but got %v", i, tc.expectedValue, value) + } + } +} diff --git a/plugin/rewrite/README.md b/plugin/rewrite/README.md index 680e69722..4e2e49a3a 100644 --- a/plugin/rewrite/README.md +++ b/plugin/rewrite/README.md @@ -206,13 +206,17 @@ rewrites the first local option with code 0xffee, setting the data to "abcd". Eq } ~~~ -* A variable data is specified with a pair of curly brackets `{}`. Following are the supported variables: +* A variable data is specified with a pair of curly brackets `{}`. Following are the supported variables by default: {qname}, {qtype}, {client_ip}, {client_port}, {protocol}, {server_ip}, {server_port}. +Any plugin that can provide it's own additional variables by implementing metadata.Provider interface. If you are going to use metadata variables then metadata plugin must be enabled. Example: -~~~ -rewrite edns0 local set 0xffee {client_ip} +~~~ corefile +. { + metadata + rewrite edns0 local set 0xffee {client_ip} +} ~~~ ### EDNS0_NSID diff --git a/plugin/rewrite/class.go b/plugin/rewrite/class.go index 2e54f515c..b04dabce2 100644 --- a/plugin/rewrite/class.go +++ b/plugin/rewrite/class.go @@ -1,6 +1,7 @@ package rewrite import ( + "context" "fmt" "strings" @@ -27,7 +28,7 @@ func newClassRule(nextAction string, args ...string) (Rule, error) { } // Rewrite rewrites the the current request. -func (rule *classRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result { +func (rule *classRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result { if rule.fromClass > 0 && rule.toClass > 0 { if r.Question[0].Qclass == rule.fromClass { r.Question[0].Qclass = rule.toClass diff --git a/plugin/rewrite/edns0.go b/plugin/rewrite/edns0.go index 2fd42cb67..f8b65d468 100644 --- a/plugin/rewrite/edns0.go +++ b/plugin/rewrite/edns0.go @@ -2,13 +2,15 @@ package rewrite import ( - "encoding/binary" + "context" "encoding/hex" "fmt" "net" "strconv" "strings" + "github.com/coredns/coredns/plugin/metadata" + "github.com/coredns/coredns/plugin/pkg/variables" "github.com/coredns/coredns/request" "github.com/miekg/dns" ) @@ -46,7 +48,7 @@ func setupEdns0Opt(r *dns.Msg) *dns.OPT { } // Rewrite will alter the request EDNS0 NSID option -func (rule *edns0NsidRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result { +func (rule *edns0NsidRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result { result := RewriteIgnored o := setupEdns0Opt(r) found := false @@ -83,7 +85,7 @@ func (rule *edns0NsidRule) GetResponseRule() ResponseRule { } // Rewrite will alter the request EDNS0 local options -func (rule *edns0LocalRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result { +func (rule *edns0LocalRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result { result := RewriteIgnored o := setupEdns0Opt(r) found := false @@ -146,7 +148,9 @@ func newEdns0Rule(mode string, args ...string) (Rule, error) { } //Check for variable option if strings.HasPrefix(args[3], "{") && strings.HasSuffix(args[3], "}") { - return newEdns0VariableRule(mode, action, args[2], args[3]) + // Remove first and last runes + variable := args[3][1 : len(args[3])-1] + return newEdns0VariableRule(mode, action, args[2], variable) } return newEdns0LocalRule(mode, action, args[2], args[3]) case "nsid": @@ -186,102 +190,28 @@ func newEdns0VariableRule(mode, action, code, variable string) (*edns0VariableRu if err != nil { return nil, err } - //Validate - if !isValidVariable(variable) { - return nil, fmt.Errorf("unsupported variable name %q", variable) - } return &edns0VariableRule{mode: mode, action: action, code: uint16(c), variable: variable}, nil } -// ipToWire writes IP address to wire/binary format, 4 or 16 bytes depends on IPV4 or IPV6. -func (rule *edns0VariableRule) ipToWire(family int, ipAddr string) ([]byte, error) { - - switch family { - case 1: - return net.ParseIP(ipAddr).To4(), nil - case 2: - return net.ParseIP(ipAddr).To16(), nil - } - return nil, fmt.Errorf("invalid IP address family (i.e. version) %d", family) -} - -// uint16ToWire writes unit16 to wire/binary format -func (rule *edns0VariableRule) uint16ToWire(data uint16) []byte { - buf := make([]byte, 2) - binary.BigEndian.PutUint16(buf, uint16(data)) - return buf -} - -// portToWire writes port to wire/binary format, 2 bytes -func (rule *edns0VariableRule) portToWire(portStr string) ([]byte, error) { - - port, err := strconv.ParseUint(portStr, 10, 16) - if err != nil { - return nil, err - } - return rule.uint16ToWire(uint16(port)), nil -} - -// Family returns the family of the transport, 1 for IPv4 and 2 for IPv6. -func (rule *edns0VariableRule) family(ip net.Addr) int { - var a net.IP - if i, ok := ip.(*net.UDPAddr); ok { - a = i.IP - } - if i, ok := ip.(*net.TCPAddr); ok { - a = i.IP - } - if a.To4() != nil { - return 1 - } - return 2 -} - // ruleData returns the data specified by the variable -func (rule *edns0VariableRule) ruleData(w dns.ResponseWriter, r *dns.Msg) ([]byte, error) { - - req := request.Request{W: w, Req: r} - switch rule.variable { - case queryName: - //Query name is written as ascii string - return []byte(req.QName()), nil - - case queryType: - return rule.uint16ToWire(req.QType()), nil - - case clientIP: - return rule.ipToWire(req.Family(), req.IP()) - - case clientPort: - return rule.portToWire(req.Port()) - - case protocol: - // Proto is written as ascii string - return []byte(req.Proto()), nil - - case serverIP: - ip, _, err := net.SplitHostPort(w.LocalAddr().String()) - if err != nil { - ip = w.RemoteAddr().String() +func (rule *edns0VariableRule) ruleData(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) ([]byte, error) { + if md, ok := metadata.FromContext(ctx); ok { + if value, ok := md.Value(rule.variable); ok { + if v, ok := value.([]byte); ok { + return v, nil + } } - return rule.ipToWire(rule.family(w.RemoteAddr()), ip) - - case serverPort: - _, port, err := net.SplitHostPort(w.LocalAddr().String()) - if err != nil { - port = "0" - } - return rule.portToWire(port) + } else { // No metadata available means metadata plugin is disabled. Try to get the value directly. + return variables.GetValue(rule.variable, w, r) } - return nil, fmt.Errorf("unable to extract data for variable %s", rule.variable) } // Rewrite will alter the request EDNS0 local options with specified variables -func (rule *edns0VariableRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result { +func (rule *edns0VariableRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result { result := RewriteIgnored - data, err := rule.ruleData(w, r) + data, err := rule.ruleData(ctx, w, r) if err != nil || data == nil { return result } @@ -324,21 +254,6 @@ func (rule *edns0VariableRule) GetResponseRule() ResponseRule { return ResponseRule{} } -func isValidVariable(variable string) bool { - switch variable { - case - queryName, - queryType, - clientIP, - clientPort, - protocol, - serverIP, - serverPort: - return true - } - return false -} - // ends0SubnetRule is a rewrite rule for EDNS0 subnet options type edns0SubnetRule struct { mode string @@ -400,7 +315,7 @@ func (rule *edns0SubnetRule) fillEcsData(w dns.ResponseWriter, r *dns.Msg, } // Rewrite will alter the request EDNS0 subnet option -func (rule *edns0SubnetRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result { +func (rule *edns0SubnetRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result { result := RewriteIgnored o := setupEdns0Opt(r) found := false @@ -446,17 +361,6 @@ const ( Append = "append" ) -// Supported local EDNS0 variables -const ( - queryName = "{qname}" - queryType = "{qtype}" - clientIP = "{client_ip}" - clientPort = "{client_port}" - protocol = "{protocol}" - serverIP = "{server_ip}" - serverPort = "{server_port}" -) - // Subnet maximum bit mask length const ( maxV4BitMaskLen = 32 diff --git a/plugin/rewrite/name.go b/plugin/rewrite/name.go index a34b4804b..4f9bb14f3 100644 --- a/plugin/rewrite/name.go +++ b/plugin/rewrite/name.go @@ -1,6 +1,7 @@ package rewrite import ( + "context" "fmt" "regexp" "strconv" @@ -56,7 +57,7 @@ const ( // Rewrite rewrites the current request based upon exact match of the name // in the question section of the request -func (rule *nameRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result { +func (rule *nameRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result { if rule.From == r.Question[0].Name { r.Question[0].Name = rule.To return RewriteDone @@ -65,7 +66,7 @@ func (rule *nameRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result { } // Rewrite rewrites the current request when the name begins with the matching string -func (rule *prefixNameRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result { +func (rule *prefixNameRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result { if strings.HasPrefix(r.Question[0].Name, rule.Prefix) { r.Question[0].Name = rule.Replacement + strings.TrimLeft(r.Question[0].Name, rule.Prefix) return RewriteDone @@ -74,7 +75,7 @@ func (rule *prefixNameRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result { } // Rewrite rewrites the current request when the name ends with the matching string -func (rule *suffixNameRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result { +func (rule *suffixNameRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result { if strings.HasSuffix(r.Question[0].Name, rule.Suffix) { r.Question[0].Name = strings.TrimRight(r.Question[0].Name, rule.Suffix) + rule.Replacement return RewriteDone @@ -84,7 +85,7 @@ func (rule *suffixNameRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result { // Rewrite rewrites the current request based upon partial match of the // name in the question section of the request -func (rule *substringNameRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result { +func (rule *substringNameRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result { if strings.Contains(r.Question[0].Name, rule.Substring) { r.Question[0].Name = strings.Replace(r.Question[0].Name, rule.Substring, rule.Replacement, -1) return RewriteDone @@ -94,7 +95,7 @@ func (rule *substringNameRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result // Rewrite rewrites the current request when the name in the question // section of the request matches a regular expression -func (rule *regexNameRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result { +func (rule *regexNameRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result { regexGroups := rule.Pattern.FindStringSubmatch(r.Question[0].Name) if len(regexGroups) == 0 { return RewriteIgnored diff --git a/plugin/rewrite/rewrite.go b/plugin/rewrite/rewrite.go index 9b61ee123..422ebd9c6 100644 --- a/plugin/rewrite/rewrite.go +++ b/plugin/rewrite/rewrite.go @@ -42,7 +42,7 @@ type Rewrite struct { func (rw Rewrite) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { wr := NewResponseReverter(w, r) for _, rule := range rw.Rules { - switch result := rule.Rewrite(w, r); result { + switch result := rule.Rewrite(ctx, w, r); result { case RewriteDone: respRule := rule.GetResponseRule() if respRule.Active == true { @@ -76,7 +76,7 @@ func (rw Rewrite) Name() string { return "rewrite" } // Rule describes a rewrite rule. type Rule interface { // Rewrite rewrites the current request. - Rewrite(dns.ResponseWriter, *dns.Msg) Result + Rewrite(context.Context, dns.ResponseWriter, *dns.Msg) Result // Mode returns the processing mode stop or continue. Mode() string // GetResponseRule returns the rule to rewrite response with, if any. diff --git a/plugin/rewrite/rewrite_test.go b/plugin/rewrite/rewrite_test.go index 56c446f49..b35543b9b 100644 --- a/plugin/rewrite/rewrite_test.go +++ b/plugin/rewrite/rewrite_test.go @@ -71,7 +71,7 @@ func TestNewRule(t *testing.T) { {[]string{"edns0", "nsid", "append"}, false, reflect.TypeOf(&edns0NsidRule{})}, {[]string{"edns0", "nsid", "replace"}, false, reflect.TypeOf(&edns0NsidRule{})}, {[]string{"edns0", "nsid", "foo"}, true, nil}, - {[]string{"edns0", "local", "set", "0xffee", "{dummy}"}, true, nil}, + {[]string{"edns0", "local", "set", "0xffee", "{dummy}"}, false, reflect.TypeOf(&edns0VariableRule{})}, {[]string{"edns0", "local", "set", "0xffee", "{qname}"}, false, reflect.TypeOf(&edns0VariableRule{})}, {[]string{"edns0", "local", "set", "0xffee", "{qtype}"}, false, reflect.TypeOf(&edns0VariableRule{})}, {[]string{"edns0", "local", "set", "0xffee", "{client_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})}, @@ -79,7 +79,7 @@ func TestNewRule(t *testing.T) { {[]string{"edns0", "local", "set", "0xffee", "{protocol}"}, false, reflect.TypeOf(&edns0VariableRule{})}, {[]string{"edns0", "local", "set", "0xffee", "{server_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})}, {[]string{"edns0", "local", "set", "0xffee", "{server_port}"}, false, reflect.TypeOf(&edns0VariableRule{})}, - {[]string{"edns0", "local", "append", "0xffee", "{dummy}"}, true, nil}, + {[]string{"edns0", "local", "append", "0xffee", "{dummy}"}, false, reflect.TypeOf(&edns0VariableRule{})}, {[]string{"edns0", "local", "append", "0xffee", "{qname}"}, false, reflect.TypeOf(&edns0VariableRule{})}, {[]string{"edns0", "local", "append", "0xffee", "{qtype}"}, false, reflect.TypeOf(&edns0VariableRule{})}, {[]string{"edns0", "local", "append", "0xffee", "{client_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})}, @@ -87,7 +87,7 @@ func TestNewRule(t *testing.T) { {[]string{"edns0", "local", "append", "0xffee", "{protocol}"}, false, reflect.TypeOf(&edns0VariableRule{})}, {[]string{"edns0", "local", "append", "0xffee", "{server_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})}, {[]string{"edns0", "local", "append", "0xffee", "{server_port}"}, false, reflect.TypeOf(&edns0VariableRule{})}, - {[]string{"edns0", "local", "replace", "0xffee", "{dummy}"}, true, nil}, + {[]string{"edns0", "local", "replace", "0xffee", "{dummy}"}, false, reflect.TypeOf(&edns0VariableRule{})}, {[]string{"edns0", "local", "replace", "0xffee", "{qname}"}, false, reflect.TypeOf(&edns0VariableRule{})}, {[]string{"edns0", "local", "replace", "0xffee", "{qtype}"}, false, reflect.TypeOf(&edns0VariableRule{})}, {[]string{"edns0", "local", "replace", "0xffee", "{client_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})}, diff --git a/plugin/rewrite/type.go b/plugin/rewrite/type.go index ec36b0b0a..c5c545485 100644 --- a/plugin/rewrite/type.go +++ b/plugin/rewrite/type.go @@ -2,6 +2,7 @@ package rewrite import ( + "context" "fmt" "strings" @@ -28,7 +29,7 @@ func newTypeRule(nextAction string, args ...string) (Rule, error) { } // Rewrite rewrites the the current request. -func (rule *typeRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result { +func (rule *typeRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result { if rule.fromType > 0 && rule.toType > 0 { if r.Question[0].Qtype == rule.fromType { r.Question[0].Qtype = rule.toType