plugin/metadata: add metadata plugin (#1894)

* plugin/metadata: add metadata plugin

* plugin/metadata: Add MD struct, refactor code, fix doc

* plugin/metadata: simplify metadata key

* plugin/metadata: improve setup_test

* Support of metadata by rewrite plugin. Move calculated variables to metadata.

* Move variables from metadata to pkg, add UTs, READMEs change, metadata small fixes

* Add client port validation to variables_test

* plugin/metadata: improve README

* plugin/metadata: rename methods

* plugin/metadata: Update Metadataer interface, update doc, cosmetic code changes

* plugin/metadata: move colllisions check to OnStartup(). Fix default variables metadataer.

* plugin/metadata: Fix comment for method setValue

* plugin/metadata: change variables order to fix linter warning

* plugin/metadata: rename Metadataer to Provider
This commit is contained in:
Eugen Kleiner 2018-06-29 12:44:16 +03:00 committed by Miek Gieben
parent dae506b563
commit 17d807f05f
19 changed files with 655 additions and 130 deletions

View file

@ -10,6 +10,7 @@ package dnsserver
// (after) them during a request, but they must not
// care what plugin above them are doing.
var Directives = []string{
"metadata",
"tls",
"reload",
"nsid",

View file

@ -24,6 +24,7 @@ import (
_ "github.com/coredns/coredns/plugin/kubernetes"
_ "github.com/coredns/coredns/plugin/loadbalance"
_ "github.com/coredns/coredns/plugin/log"
_ "github.com/coredns/coredns/plugin/metadata"
_ "github.com/coredns/coredns/plugin/metrics"
_ "github.com/coredns/coredns/plugin/nsid"
_ "github.com/coredns/coredns/plugin/pprof"

View file

@ -19,6 +19,7 @@
# Local plugin example:
# log:log
metadata:metadata
tls:tls
reload:reload
nsid:nsid

47
plugin/metadata/README.md Normal file
View file

@ -0,0 +1,47 @@
# metadata
## Name
*metadata* - enable a metadata collector.
## Description
By enabling *metadata* any plugin that implements [metadata.Provider interface](https://godoc.org/github.com/coredns/coredns/plugin/metadata#Provider) will be called for each DNS query, at being of the process for that query, in order to add it's own Metadata to context. The metadata collected will be available for all plugins handler, via the Context parameter provided in the ServeDNS function.
Metadata plugin is automatically adding the so-called default medatada (extracted from the query) to the context. Those default metadata are: {qname}, {qtype}, {client_ip}, {client_port}, {protocol}, {server_ip}, {server_port}
## Syntax
~~~
metadata [ZONES... ]
~~~
## Plugins
metadata.Provider interface needs to be implemented by each plugin willing to provide metadata information for other plugins. It will be called by metadata and gather the information from all plugins in context.
Note: this method should work quickly, because it is called for every request
from the metadata plugin.
If **ZONES** is specified then metadata add is limited by zones. Metadata is added to every context going through metadata.Provider if **ZONES** are not specified.
## Examples
Enable metadata for all requests. Rewrite uses one of the provided by default metadata variables.
~~~ corefile
. {
metadata
rewrite edns0 local set 0xffee {client_ip}
forward . 8.8.8.8:53
}
~~~
Add metadata for all requests within `example.org.`. Rewrite uses one of provided by default metadata variables. Any other requests won't have metadata.
~~~ corefile
. {
metadata example.org
rewrite edns0 local set 0xffee {client_ip}
forward . 8.8.8.8:53
}
~~~

View file

@ -0,0 +1,55 @@
package metadata
import (
"context"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/pkg/variables"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
// Metadata implements collecting metadata information from all plugins that
// implement the Provider interface.
type Metadata struct {
Zones []string
Providers []Provider
Next plugin.Handler
}
// Name implements the Handler interface.
func (m *Metadata) Name() string { return "metadata" }
// ServeDNS implements the plugin.Handler interface.
func (m *Metadata) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
md, ctx := newMD(ctx)
state := request.Request{W: w, Req: r}
if plugin.Zones(m.Zones).Matches(state.Name()) != "" {
// Go through all Providers and collect metadata
for _, provider := range m.Providers {
for _, varName := range provider.MetadataVarNames() {
if val, ok := provider.Metadata(ctx, w, r, varName); ok {
md.setValue(varName, val)
}
}
}
}
rcode, err := plugin.NextOrFailure(m.Name(), m.Next, ctx, w, r)
return rcode, err
}
// MetadataVarNames implements the plugin.Provider interface.
func (m *Metadata) MetadataVarNames() []string { return variables.All }
// Metadata implements the plugin.Provider interface.
func (m *Metadata) Metadata(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, varName string) (interface{}, bool) {
if val, err := variables.GetValue(varName, w, r); err == nil {
return val, true
}
return nil, false
}

View file

@ -0,0 +1,79 @@
package metadata
import (
"context"
"testing"
"github.com/coredns/coredns/plugin/test"
"github.com/miekg/dns"
)
// testProvider implements fake Providers. Plugins which inmplement Provider interface
type testProvider map[string]interface{}
func (m testProvider) MetadataVarNames() []string {
keys := []string{}
for k := range m {
keys = append(keys, k)
}
return keys
}
func (m testProvider) Metadata(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, key string) (val interface{}, ok bool) {
value, ok := m[key]
return value, ok
}
// testHandler implements plugin.Handler
type testHandler struct{ ctx context.Context }
func (m *testHandler) Name() string { return "testHandler" }
func (m *testHandler) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
m.ctx = ctx
return 0, nil
}
func TestMetadataServDns(t *testing.T) {
expectedMetadata := []testProvider{
testProvider{"testkey1": "testvalue1"},
testProvider{"testkey2": 2, "testkey3": "testvalue3"},
}
// Create fake Providers based on expectedMetadata
providers := []Provider{}
for _, e := range expectedMetadata {
providers = append(providers, e)
}
// Fake handler which stores the resulting context
next := &testHandler{}
metadata := Metadata{
Zones: []string{"."},
Providers: providers,
Next: next,
}
metadata.ServeDNS(context.TODO(), &test.ResponseWriter{}, new(dns.Msg))
// Verify that next plugin can find metadata in context from all Providers
for _, expected := range expectedMetadata {
md, ok := FromContext(next.ctx)
if !ok {
t.Fatalf("Metadata is expected but not present inside the context")
}
for expKey, expVal := range expected {
metadataVal, valOk := md.Value(expKey)
if !valOk {
t.Fatalf("Value by key %v can't be retrieved", expKey)
}
if metadataVal != expVal {
t.Errorf("Expected value %v, but got %v", expVal, metadataVal)
}
}
wrongKey := "wrong_key"
metadataVal, ok := md.Value(wrongKey)
if ok {
t.Fatalf("Value by key %v is not expected to be recieved, but got: %v", wrongKey, metadataVal)
}
}
}

View file

@ -0,0 +1,53 @@
package metadata
import (
"context"
"github.com/miekg/dns"
)
// Provider interface needs to be implemented by each plugin willing to provide
// metadata information for other plugins.
// Note: this method should work quickly, because it is called for every request
// from the metadata plugin.
type Provider interface {
// List of variables which are provided by current Provider. Must remain constant.
MetadataVarNames() []string
// Metadata is expected to return a value with metadata information by the key
// from 4th argument. Value can be later retrieved from context by any other plugin.
// If value is not available by some reason returned boolean value should be false.
Metadata(context.Context, dns.ResponseWriter, *dns.Msg, string) (interface{}, bool)
}
// MD is metadata information storage
type MD map[string]interface{}
// metadataKey defines the type of key that is used to save metadata into the context
type metadataKey struct{}
// newMD initializes MD and attaches it to context
func newMD(ctx context.Context) (MD, context.Context) {
m := MD{}
return m, context.WithValue(ctx, metadataKey{}, m)
}
// FromContext retrieves MD struct from context.
func FromContext(ctx context.Context) (md MD, ok bool) {
if metadata := ctx.Value(metadataKey{}); metadata != nil {
if md, ok := metadata.(MD); ok {
return md, true
}
}
return MD{}, false
}
// Value returns metadata value by key.
func (m MD) Value(key string) (value interface{}, ok bool) {
value, ok = m[key]
return value, ok
}
// setValue adds metadata value.
func (m MD) setValue(key string, val interface{}) {
m[key] = val
}

View file

@ -0,0 +1,47 @@
package metadata
import (
"context"
"reflect"
"testing"
)
func TestMD(t *testing.T) {
tests := []struct {
addValues map[string]interface{}
expectedValues map[string]interface{}
}{
{
// Add initial metadata key/vals
map[string]interface{}{"key1": "val1", "key2": 2},
map[string]interface{}{"key1": "val1", "key2": 2},
},
{
// Add additional key/vals.
map[string]interface{}{"key3": 3, "key4": 4.5},
map[string]interface{}{"key1": "val1", "key2": 2, "key3": 3, "key4": 4.5},
},
}
// Using one same md and ctx for all test cases
ctx := context.TODO()
md, ctx := newMD(ctx)
for i, tc := range tests {
for k, v := range tc.addValues {
md.setValue(k, v)
}
if !reflect.DeepEqual(tc.expectedValues, map[string]interface{}(md)) {
t.Errorf("Test %d: Expected %v but got %v", i, tc.expectedValues, md)
}
// Make sure that MD is recieved from context successfullly
mdFromContext, ok := FromContext(ctx)
if !ok {
t.Errorf("Test %d: MD is not recieved from the context", i)
}
if !reflect.DeepEqual(md, mdFromContext) {
t.Errorf("Test %d: MD recieved from context differs from initial. Initial: %v, from context: %v", i, md, mdFromContext)
}
}
}

71
plugin/metadata/setup.go Normal file
View file

@ -0,0 +1,71 @@
package metadata
import (
"fmt"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
"github.com/mholt/caddy"
)
func init() {
caddy.RegisterPlugin("metadata", caddy.Plugin{
ServerType: "dns",
Action: setup,
})
}
func setup(c *caddy.Controller) error {
m, err := metadataParse(c)
if err != nil {
return err
}
dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler {
m.Next = next
return m
})
c.OnStartup(func() error {
plugins := dnsserver.GetConfig(c).Handlers()
// Collect all plugins which implement Provider interface
metadataVariables := map[string]bool{}
for _, p := range plugins {
if met, ok := p.(Provider); ok {
for _, varName := range met.MetadataVarNames() {
if _, ok := metadataVariables[varName]; ok {
return fmt.Errorf("Metadata variable '%v' has duplicates", varName)
}
metadataVariables[varName] = true
}
m.Providers = append(m.Providers, met)
}
}
return nil
})
return nil
}
func metadataParse(c *caddy.Controller) (*Metadata, error) {
m := &Metadata{}
c.Next()
zones := c.RemainingArgs()
if len(zones) != 0 {
m.Zones = zones
for i := 0; i < len(m.Zones); i++ {
m.Zones[i] = plugin.Host(m.Zones[i]).Normalize()
}
} else {
m.Zones = make([]string, len(c.ServerBlockKeys))
for i := 0; i < len(c.ServerBlockKeys); i++ {
m.Zones[i] = plugin.Host(c.ServerBlockKeys[i]).Normalize()
}
}
if c.NextBlock() || c.Next() {
return nil, plugin.Error("metadata", c.ArgErr())
}
return m, nil
}

View file

@ -0,0 +1,70 @@
package metadata
import (
"reflect"
"testing"
"github.com/mholt/caddy"
)
func TestSetup(t *testing.T) {
tests := []struct {
input string
zones []string
shouldErr bool
}{
{"metadata", []string{}, false},
{"metadata example.com.", []string{"example.com."}, false},
{"metadata example.com. net.", []string{"example.com.", "net."}, false},
{"metadata example.com. { some_param }", []string{}, true},
{"metadata\nmetadata", []string{}, true},
}
for i, test := range tests {
c := caddy.NewTestController("dns", test.input)
err := setup(c)
if test.shouldErr && err == nil {
t.Errorf("Test %d: Setup call expected error but found none for input %s", i, test.input)
}
if !test.shouldErr && err != nil {
t.Errorf("Test %d: Setup call expected no error but found one for input %s. Error was: %v", i, test.input, err)
}
}
}
func TestSetupHealth(t *testing.T) {
tests := []struct {
input string
zones []string
shouldErr bool
}{
{"metadata", []string{}, false},
{"metadata example.com.", []string{"example.com."}, false},
{"metadata example.com. net.", []string{"example.com.", "net."}, false},
{"metadata example.com. { some_param }", []string{}, true},
{"metadata\nmetadata", []string{}, true},
}
for i, test := range tests {
c := caddy.NewTestController("dns", test.input)
m, err := metadataParse(c)
if test.shouldErr && err == nil {
t.Errorf("Test %d: Expected error but found none for input %s", i, test.input)
}
if !test.shouldErr && err != nil {
t.Errorf("Test %d: Expected no error but found one for input %s. Error was: %v", i, test.input, err)
}
if !test.shouldErr && err == nil {
if !reflect.DeepEqual(test.zones, m.Zones) {
t.Errorf("Test %d: Expected zones %s. Zones were: %v", i, test.zones, m.Zones)
}
}
}
}

View file

@ -0,0 +1,109 @@
package variables
import (
"encoding/binary"
"fmt"
"net"
"strconv"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
const (
queryName = "qname"
queryType = "qtype"
clientIP = "client_ip"
clientPort = "client_port"
protocol = "protocol"
serverIP = "server_ip"
serverPort = "server_port"
)
// All is a list of available variables provided by GetMetadataValue
var All = []string{queryName, queryType, clientIP, clientPort, protocol, serverIP, serverPort}
// GetValue calculates and returns the data specified by the variable name.
// Supported varNames are listed in allProvidedVars.
func GetValue(varName string, w dns.ResponseWriter, r *dns.Msg) ([]byte, error) {
req := request.Request{W: w, Req: r}
switch varName {
case queryName:
//Query name is written as ascii string
return []byte(req.QName()), nil
case queryType:
return uint16ToWire(req.QType()), nil
case clientIP:
return ipToWire(req.Family(), req.IP())
case clientPort:
return portToWire(req.Port())
case protocol:
// Proto is written as ascii string
return []byte(req.Proto()), nil
case serverIP:
ip, _, err := net.SplitHostPort(w.LocalAddr().String())
if err != nil {
ip = w.RemoteAddr().String()
}
return ipToWire(family(w.RemoteAddr()), ip)
case serverPort:
_, port, err := net.SplitHostPort(w.LocalAddr().String())
if err != nil {
port = "0"
}
return portToWire(port)
}
return nil, fmt.Errorf("unable to extract data for variable %s", varName)
}
// uint16ToWire writes unit16 to wire/binary format
func uint16ToWire(data uint16) []byte {
buf := make([]byte, 2)
binary.BigEndian.PutUint16(buf, uint16(data))
return buf
}
// ipToWire writes IP address to wire/binary format, 4 or 16 bytes depends on IPV4 or IPV6.
func ipToWire(family int, ipAddr string) ([]byte, error) {
switch family {
case 1:
return net.ParseIP(ipAddr).To4(), nil
case 2:
return net.ParseIP(ipAddr).To16(), nil
}
return nil, fmt.Errorf("invalid IP address family (i.e. version) %d", family)
}
// portToWire writes port to wire/binary format, 2 bytes
func portToWire(portStr string) ([]byte, error) {
port, err := strconv.ParseUint(portStr, 10, 16)
if err != nil {
return nil, err
}
return uint16ToWire(uint16(port)), nil
}
// Family returns the family of the transport, 1 for IPv4 and 2 for IPv6.
func family(ip net.Addr) int {
var a net.IP
if i, ok := ip.(*net.UDPAddr); ok {
a = i.IP
}
if i, ok := ip.(*net.TCPAddr); ok {
a = i.IP
}
if a.To4() != nil {
return 1
}
return 2
}

View file

@ -0,0 +1,80 @@
package variables
import (
"bytes"
"testing"
"github.com/coredns/coredns/plugin/test"
"github.com/miekg/dns"
)
func TestGetValue(t *testing.T) {
// test.ResponseWriter has the following values:
// The remote will always be 10.240.0.1 and port 40212.
// The local address is always 127.0.0.1 and port 53.
tests := []struct {
varName string
expectedValue []byte
shouldErr bool
}{
{
queryName,
[]byte("example.com."),
false,
},
{
queryType,
[]byte{0x00, 0x01},
false,
},
{
clientIP,
[]byte{10, 240, 0, 1},
false,
},
{
clientPort,
[]byte{0x9D, 0x14},
false,
},
{
protocol,
[]byte("udp"),
false,
},
{
serverIP,
[]byte{127, 0, 0, 1},
false,
},
{
serverPort,
[]byte{0, 53},
false,
},
{
"wrong_var",
[]byte{},
true,
},
}
for i, tc := range tests {
m := new(dns.Msg)
m.SetQuestion("example.com.", dns.TypeA)
m.Question[0].Qclass = dns.ClassINET
value, err := GetValue(tc.varName, &test.ResponseWriter{}, m)
if tc.shouldErr && err == nil {
t.Errorf("Test %d: Expected error, but didn't recieve", i)
}
if !tc.shouldErr && err != nil {
t.Errorf("Test %d: Expected no error, but got error: %v", i, err.Error())
}
if !bytes.Equal(tc.expectedValue, value) {
t.Errorf("Test %d: Expected %v but got %v", i, tc.expectedValue, value)
}
}
}

View file

@ -206,13 +206,17 @@ rewrites the first local option with code 0xffee, setting the data to "abcd". Eq
}
~~~
* A variable data is specified with a pair of curly brackets `{}`. Following are the supported variables:
* A variable data is specified with a pair of curly brackets `{}`. Following are the supported variables by default:
{qname}, {qtype}, {client_ip}, {client_port}, {protocol}, {server_ip}, {server_port}.
Any plugin that can provide it's own additional variables by implementing metadata.Provider interface. If you are going to use metadata variables then metadata plugin must be enabled.
Example:
~~~
rewrite edns0 local set 0xffee {client_ip}
~~~ corefile
. {
metadata
rewrite edns0 local set 0xffee {client_ip}
}
~~~
### EDNS0_NSID

View file

@ -1,6 +1,7 @@
package rewrite
import (
"context"
"fmt"
"strings"
@ -27,7 +28,7 @@ func newClassRule(nextAction string, args ...string) (Rule, error) {
}
// Rewrite rewrites the the current request.
func (rule *classRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
func (rule *classRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result {
if rule.fromClass > 0 && rule.toClass > 0 {
if r.Question[0].Qclass == rule.fromClass {
r.Question[0].Qclass = rule.toClass

View file

@ -2,13 +2,15 @@
package rewrite
import (
"encoding/binary"
"context"
"encoding/hex"
"fmt"
"net"
"strconv"
"strings"
"github.com/coredns/coredns/plugin/metadata"
"github.com/coredns/coredns/plugin/pkg/variables"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
@ -46,7 +48,7 @@ func setupEdns0Opt(r *dns.Msg) *dns.OPT {
}
// Rewrite will alter the request EDNS0 NSID option
func (rule *edns0NsidRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
func (rule *edns0NsidRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result {
result := RewriteIgnored
o := setupEdns0Opt(r)
found := false
@ -83,7 +85,7 @@ func (rule *edns0NsidRule) GetResponseRule() ResponseRule {
}
// Rewrite will alter the request EDNS0 local options
func (rule *edns0LocalRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
func (rule *edns0LocalRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result {
result := RewriteIgnored
o := setupEdns0Opt(r)
found := false
@ -146,7 +148,9 @@ func newEdns0Rule(mode string, args ...string) (Rule, error) {
}
//Check for variable option
if strings.HasPrefix(args[3], "{") && strings.HasSuffix(args[3], "}") {
return newEdns0VariableRule(mode, action, args[2], args[3])
// Remove first and last runes
variable := args[3][1 : len(args[3])-1]
return newEdns0VariableRule(mode, action, args[2], variable)
}
return newEdns0LocalRule(mode, action, args[2], args[3])
case "nsid":
@ -186,102 +190,28 @@ func newEdns0VariableRule(mode, action, code, variable string) (*edns0VariableRu
if err != nil {
return nil, err
}
//Validate
if !isValidVariable(variable) {
return nil, fmt.Errorf("unsupported variable name %q", variable)
}
return &edns0VariableRule{mode: mode, action: action, code: uint16(c), variable: variable}, nil
}
// ipToWire writes IP address to wire/binary format, 4 or 16 bytes depends on IPV4 or IPV6.
func (rule *edns0VariableRule) ipToWire(family int, ipAddr string) ([]byte, error) {
switch family {
case 1:
return net.ParseIP(ipAddr).To4(), nil
case 2:
return net.ParseIP(ipAddr).To16(), nil
}
return nil, fmt.Errorf("invalid IP address family (i.e. version) %d", family)
}
// uint16ToWire writes unit16 to wire/binary format
func (rule *edns0VariableRule) uint16ToWire(data uint16) []byte {
buf := make([]byte, 2)
binary.BigEndian.PutUint16(buf, uint16(data))
return buf
}
// portToWire writes port to wire/binary format, 2 bytes
func (rule *edns0VariableRule) portToWire(portStr string) ([]byte, error) {
port, err := strconv.ParseUint(portStr, 10, 16)
if err != nil {
return nil, err
}
return rule.uint16ToWire(uint16(port)), nil
}
// Family returns the family of the transport, 1 for IPv4 and 2 for IPv6.
func (rule *edns0VariableRule) family(ip net.Addr) int {
var a net.IP
if i, ok := ip.(*net.UDPAddr); ok {
a = i.IP
}
if i, ok := ip.(*net.TCPAddr); ok {
a = i.IP
}
if a.To4() != nil {
return 1
}
return 2
}
// ruleData returns the data specified by the variable
func (rule *edns0VariableRule) ruleData(w dns.ResponseWriter, r *dns.Msg) ([]byte, error) {
req := request.Request{W: w, Req: r}
switch rule.variable {
case queryName:
//Query name is written as ascii string
return []byte(req.QName()), nil
case queryType:
return rule.uint16ToWire(req.QType()), nil
case clientIP:
return rule.ipToWire(req.Family(), req.IP())
case clientPort:
return rule.portToWire(req.Port())
case protocol:
// Proto is written as ascii string
return []byte(req.Proto()), nil
case serverIP:
ip, _, err := net.SplitHostPort(w.LocalAddr().String())
if err != nil {
ip = w.RemoteAddr().String()
func (rule *edns0VariableRule) ruleData(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) ([]byte, error) {
if md, ok := metadata.FromContext(ctx); ok {
if value, ok := md.Value(rule.variable); ok {
if v, ok := value.([]byte); ok {
return v, nil
}
return rule.ipToWire(rule.family(w.RemoteAddr()), ip)
case serverPort:
_, port, err := net.SplitHostPort(w.LocalAddr().String())
if err != nil {
port = "0"
}
return rule.portToWire(port)
} else { // No metadata available means metadata plugin is disabled. Try to get the value directly.
return variables.GetValue(rule.variable, w, r)
}
return nil, fmt.Errorf("unable to extract data for variable %s", rule.variable)
}
// Rewrite will alter the request EDNS0 local options with specified variables
func (rule *edns0VariableRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
func (rule *edns0VariableRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result {
result := RewriteIgnored
data, err := rule.ruleData(w, r)
data, err := rule.ruleData(ctx, w, r)
if err != nil || data == nil {
return result
}
@ -324,21 +254,6 @@ func (rule *edns0VariableRule) GetResponseRule() ResponseRule {
return ResponseRule{}
}
func isValidVariable(variable string) bool {
switch variable {
case
queryName,
queryType,
clientIP,
clientPort,
protocol,
serverIP,
serverPort:
return true
}
return false
}
// ends0SubnetRule is a rewrite rule for EDNS0 subnet options
type edns0SubnetRule struct {
mode string
@ -400,7 +315,7 @@ func (rule *edns0SubnetRule) fillEcsData(w dns.ResponseWriter, r *dns.Msg,
}
// Rewrite will alter the request EDNS0 subnet option
func (rule *edns0SubnetRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
func (rule *edns0SubnetRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result {
result := RewriteIgnored
o := setupEdns0Opt(r)
found := false
@ -446,17 +361,6 @@ const (
Append = "append"
)
// Supported local EDNS0 variables
const (
queryName = "{qname}"
queryType = "{qtype}"
clientIP = "{client_ip}"
clientPort = "{client_port}"
protocol = "{protocol}"
serverIP = "{server_ip}"
serverPort = "{server_port}"
)
// Subnet maximum bit mask length
const (
maxV4BitMaskLen = 32

View file

@ -1,6 +1,7 @@
package rewrite
import (
"context"
"fmt"
"regexp"
"strconv"
@ -56,7 +57,7 @@ const (
// Rewrite rewrites the current request based upon exact match of the name
// in the question section of the request
func (rule *nameRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
func (rule *nameRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result {
if rule.From == r.Question[0].Name {
r.Question[0].Name = rule.To
return RewriteDone
@ -65,7 +66,7 @@ func (rule *nameRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
}
// Rewrite rewrites the current request when the name begins with the matching string
func (rule *prefixNameRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
func (rule *prefixNameRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result {
if strings.HasPrefix(r.Question[0].Name, rule.Prefix) {
r.Question[0].Name = rule.Replacement + strings.TrimLeft(r.Question[0].Name, rule.Prefix)
return RewriteDone
@ -74,7 +75,7 @@ func (rule *prefixNameRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
}
// Rewrite rewrites the current request when the name ends with the matching string
func (rule *suffixNameRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
func (rule *suffixNameRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result {
if strings.HasSuffix(r.Question[0].Name, rule.Suffix) {
r.Question[0].Name = strings.TrimRight(r.Question[0].Name, rule.Suffix) + rule.Replacement
return RewriteDone
@ -84,7 +85,7 @@ func (rule *suffixNameRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
// Rewrite rewrites the current request based upon partial match of the
// name in the question section of the request
func (rule *substringNameRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
func (rule *substringNameRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result {
if strings.Contains(r.Question[0].Name, rule.Substring) {
r.Question[0].Name = strings.Replace(r.Question[0].Name, rule.Substring, rule.Replacement, -1)
return RewriteDone
@ -94,7 +95,7 @@ func (rule *substringNameRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result
// Rewrite rewrites the current request when the name in the question
// section of the request matches a regular expression
func (rule *regexNameRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
func (rule *regexNameRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result {
regexGroups := rule.Pattern.FindStringSubmatch(r.Question[0].Name)
if len(regexGroups) == 0 {
return RewriteIgnored

View file

@ -42,7 +42,7 @@ type Rewrite struct {
func (rw Rewrite) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
wr := NewResponseReverter(w, r)
for _, rule := range rw.Rules {
switch result := rule.Rewrite(w, r); result {
switch result := rule.Rewrite(ctx, w, r); result {
case RewriteDone:
respRule := rule.GetResponseRule()
if respRule.Active == true {
@ -76,7 +76,7 @@ func (rw Rewrite) Name() string { return "rewrite" }
// Rule describes a rewrite rule.
type Rule interface {
// Rewrite rewrites the current request.
Rewrite(dns.ResponseWriter, *dns.Msg) Result
Rewrite(context.Context, dns.ResponseWriter, *dns.Msg) Result
// Mode returns the processing mode stop or continue.
Mode() string
// GetResponseRule returns the rule to rewrite response with, if any.

View file

@ -71,7 +71,7 @@ func TestNewRule(t *testing.T) {
{[]string{"edns0", "nsid", "append"}, false, reflect.TypeOf(&edns0NsidRule{})},
{[]string{"edns0", "nsid", "replace"}, false, reflect.TypeOf(&edns0NsidRule{})},
{[]string{"edns0", "nsid", "foo"}, true, nil},
{[]string{"edns0", "local", "set", "0xffee", "{dummy}"}, true, nil},
{[]string{"edns0", "local", "set", "0xffee", "{dummy}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"edns0", "local", "set", "0xffee", "{qname}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"edns0", "local", "set", "0xffee", "{qtype}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"edns0", "local", "set", "0xffee", "{client_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})},
@ -79,7 +79,7 @@ func TestNewRule(t *testing.T) {
{[]string{"edns0", "local", "set", "0xffee", "{protocol}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"edns0", "local", "set", "0xffee", "{server_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"edns0", "local", "set", "0xffee", "{server_port}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"edns0", "local", "append", "0xffee", "{dummy}"}, true, nil},
{[]string{"edns0", "local", "append", "0xffee", "{dummy}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"edns0", "local", "append", "0xffee", "{qname}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"edns0", "local", "append", "0xffee", "{qtype}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"edns0", "local", "append", "0xffee", "{client_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})},
@ -87,7 +87,7 @@ func TestNewRule(t *testing.T) {
{[]string{"edns0", "local", "append", "0xffee", "{protocol}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"edns0", "local", "append", "0xffee", "{server_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"edns0", "local", "append", "0xffee", "{server_port}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"edns0", "local", "replace", "0xffee", "{dummy}"}, true, nil},
{[]string{"edns0", "local", "replace", "0xffee", "{dummy}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"edns0", "local", "replace", "0xffee", "{qname}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"edns0", "local", "replace", "0xffee", "{qtype}"}, false, reflect.TypeOf(&edns0VariableRule{})},
{[]string{"edns0", "local", "replace", "0xffee", "{client_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})},

View file

@ -2,6 +2,7 @@
package rewrite
import (
"context"
"fmt"
"strings"
@ -28,7 +29,7 @@ func newTypeRule(nextAction string, args ...string) (Rule, error) {
}
// Rewrite rewrites the the current request.
func (rule *typeRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
func (rule *typeRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result {
if rule.fromType > 0 && rule.toType > 0 {
if r.Question[0].Qtype == rule.fromType {
r.Question[0].Qtype = rule.toType