plugin/metadata: add metadata plugin (#1894)
* plugin/metadata: add metadata plugin * plugin/metadata: Add MD struct, refactor code, fix doc * plugin/metadata: simplify metadata key * plugin/metadata: improve setup_test * Support of metadata by rewrite plugin. Move calculated variables to metadata. * Move variables from metadata to pkg, add UTs, READMEs change, metadata small fixes * Add client port validation to variables_test * plugin/metadata: improve README * plugin/metadata: rename methods * plugin/metadata: Update Metadataer interface, update doc, cosmetic code changes * plugin/metadata: move colllisions check to OnStartup(). Fix default variables metadataer. * plugin/metadata: Fix comment for method setValue * plugin/metadata: change variables order to fix linter warning * plugin/metadata: rename Metadataer to Provider
This commit is contained in:
parent
dae506b563
commit
17d807f05f
19 changed files with 655 additions and 130 deletions
|
@ -10,6 +10,7 @@ package dnsserver
|
||||||
// (after) them during a request, but they must not
|
// (after) them during a request, but they must not
|
||||||
// care what plugin above them are doing.
|
// care what plugin above them are doing.
|
||||||
var Directives = []string{
|
var Directives = []string{
|
||||||
|
"metadata",
|
||||||
"tls",
|
"tls",
|
||||||
"reload",
|
"reload",
|
||||||
"nsid",
|
"nsid",
|
||||||
|
|
|
@ -24,6 +24,7 @@ import (
|
||||||
_ "github.com/coredns/coredns/plugin/kubernetes"
|
_ "github.com/coredns/coredns/plugin/kubernetes"
|
||||||
_ "github.com/coredns/coredns/plugin/loadbalance"
|
_ "github.com/coredns/coredns/plugin/loadbalance"
|
||||||
_ "github.com/coredns/coredns/plugin/log"
|
_ "github.com/coredns/coredns/plugin/log"
|
||||||
|
_ "github.com/coredns/coredns/plugin/metadata"
|
||||||
_ "github.com/coredns/coredns/plugin/metrics"
|
_ "github.com/coredns/coredns/plugin/metrics"
|
||||||
_ "github.com/coredns/coredns/plugin/nsid"
|
_ "github.com/coredns/coredns/plugin/nsid"
|
||||||
_ "github.com/coredns/coredns/plugin/pprof"
|
_ "github.com/coredns/coredns/plugin/pprof"
|
||||||
|
|
|
@ -19,6 +19,7 @@
|
||||||
# Local plugin example:
|
# Local plugin example:
|
||||||
# log:log
|
# log:log
|
||||||
|
|
||||||
|
metadata:metadata
|
||||||
tls:tls
|
tls:tls
|
||||||
reload:reload
|
reload:reload
|
||||||
nsid:nsid
|
nsid:nsid
|
||||||
|
|
47
plugin/metadata/README.md
Normal file
47
plugin/metadata/README.md
Normal file
|
@ -0,0 +1,47 @@
|
||||||
|
# metadata
|
||||||
|
|
||||||
|
## Name
|
||||||
|
|
||||||
|
*metadata* - enable a metadata collector.
|
||||||
|
|
||||||
|
## Description
|
||||||
|
|
||||||
|
By enabling *metadata* any plugin that implements [metadata.Provider interface](https://godoc.org/github.com/coredns/coredns/plugin/metadata#Provider) will be called for each DNS query, at being of the process for that query, in order to add it's own Metadata to context. The metadata collected will be available for all plugins handler, via the Context parameter provided in the ServeDNS function.
|
||||||
|
Metadata plugin is automatically adding the so-called default medatada (extracted from the query) to the context. Those default metadata are: {qname}, {qtype}, {client_ip}, {client_port}, {protocol}, {server_ip}, {server_port}
|
||||||
|
|
||||||
|
|
||||||
|
## Syntax
|
||||||
|
|
||||||
|
~~~
|
||||||
|
metadata [ZONES... ]
|
||||||
|
~~~
|
||||||
|
|
||||||
|
## Plugins
|
||||||
|
|
||||||
|
metadata.Provider interface needs to be implemented by each plugin willing to provide metadata information for other plugins. It will be called by metadata and gather the information from all plugins in context.
|
||||||
|
Note: this method should work quickly, because it is called for every request
|
||||||
|
from the metadata plugin.
|
||||||
|
If **ZONES** is specified then metadata add is limited by zones. Metadata is added to every context going through metadata.Provider if **ZONES** are not specified.
|
||||||
|
|
||||||
|
|
||||||
|
## Examples
|
||||||
|
|
||||||
|
Enable metadata for all requests. Rewrite uses one of the provided by default metadata variables.
|
||||||
|
|
||||||
|
~~~ corefile
|
||||||
|
. {
|
||||||
|
metadata
|
||||||
|
rewrite edns0 local set 0xffee {client_ip}
|
||||||
|
forward . 8.8.8.8:53
|
||||||
|
}
|
||||||
|
~~~
|
||||||
|
|
||||||
|
Add metadata for all requests within `example.org.`. Rewrite uses one of provided by default metadata variables. Any other requests won't have metadata.
|
||||||
|
|
||||||
|
~~~ corefile
|
||||||
|
. {
|
||||||
|
metadata example.org
|
||||||
|
rewrite edns0 local set 0xffee {client_ip}
|
||||||
|
forward . 8.8.8.8:53
|
||||||
|
}
|
||||||
|
~~~
|
55
plugin/metadata/metadata.go
Normal file
55
plugin/metadata/metadata.go
Normal file
|
@ -0,0 +1,55 @@
|
||||||
|
package metadata
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/coredns/coredns/plugin"
|
||||||
|
"github.com/coredns/coredns/plugin/pkg/variables"
|
||||||
|
"github.com/coredns/coredns/request"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Metadata implements collecting metadata information from all plugins that
|
||||||
|
// implement the Provider interface.
|
||||||
|
type Metadata struct {
|
||||||
|
Zones []string
|
||||||
|
Providers []Provider
|
||||||
|
Next plugin.Handler
|
||||||
|
}
|
||||||
|
|
||||||
|
// Name implements the Handler interface.
|
||||||
|
func (m *Metadata) Name() string { return "metadata" }
|
||||||
|
|
||||||
|
// ServeDNS implements the plugin.Handler interface.
|
||||||
|
func (m *Metadata) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||||
|
|
||||||
|
md, ctx := newMD(ctx)
|
||||||
|
|
||||||
|
state := request.Request{W: w, Req: r}
|
||||||
|
if plugin.Zones(m.Zones).Matches(state.Name()) != "" {
|
||||||
|
// Go through all Providers and collect metadata
|
||||||
|
for _, provider := range m.Providers {
|
||||||
|
for _, varName := range provider.MetadataVarNames() {
|
||||||
|
if val, ok := provider.Metadata(ctx, w, r, varName); ok {
|
||||||
|
md.setValue(varName, val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
rcode, err := plugin.NextOrFailure(m.Name(), m.Next, ctx, w, r)
|
||||||
|
|
||||||
|
return rcode, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// MetadataVarNames implements the plugin.Provider interface.
|
||||||
|
func (m *Metadata) MetadataVarNames() []string { return variables.All }
|
||||||
|
|
||||||
|
// Metadata implements the plugin.Provider interface.
|
||||||
|
func (m *Metadata) Metadata(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, varName string) (interface{}, bool) {
|
||||||
|
if val, err := variables.GetValue(varName, w, r); err == nil {
|
||||||
|
return val, true
|
||||||
|
}
|
||||||
|
return nil, false
|
||||||
|
}
|
79
plugin/metadata/metadata_test.go
Normal file
79
plugin/metadata/metadata_test.go
Normal file
|
@ -0,0 +1,79 @@
|
||||||
|
package metadata
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/coredns/coredns/plugin/test"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
)
|
||||||
|
|
||||||
|
// testProvider implements fake Providers. Plugins which inmplement Provider interface
|
||||||
|
type testProvider map[string]interface{}
|
||||||
|
|
||||||
|
func (m testProvider) MetadataVarNames() []string {
|
||||||
|
keys := []string{}
|
||||||
|
for k := range m {
|
||||||
|
keys = append(keys, k)
|
||||||
|
}
|
||||||
|
return keys
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m testProvider) Metadata(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, key string) (val interface{}, ok bool) {
|
||||||
|
value, ok := m[key]
|
||||||
|
return value, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// testHandler implements plugin.Handler
|
||||||
|
type testHandler struct{ ctx context.Context }
|
||||||
|
|
||||||
|
func (m *testHandler) Name() string { return "testHandler" }
|
||||||
|
|
||||||
|
func (m *testHandler) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||||
|
m.ctx = ctx
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMetadataServDns(t *testing.T) {
|
||||||
|
expectedMetadata := []testProvider{
|
||||||
|
testProvider{"testkey1": "testvalue1"},
|
||||||
|
testProvider{"testkey2": 2, "testkey3": "testvalue3"},
|
||||||
|
}
|
||||||
|
// Create fake Providers based on expectedMetadata
|
||||||
|
providers := []Provider{}
|
||||||
|
for _, e := range expectedMetadata {
|
||||||
|
providers = append(providers, e)
|
||||||
|
}
|
||||||
|
// Fake handler which stores the resulting context
|
||||||
|
next := &testHandler{}
|
||||||
|
|
||||||
|
metadata := Metadata{
|
||||||
|
Zones: []string{"."},
|
||||||
|
Providers: providers,
|
||||||
|
Next: next,
|
||||||
|
}
|
||||||
|
metadata.ServeDNS(context.TODO(), &test.ResponseWriter{}, new(dns.Msg))
|
||||||
|
|
||||||
|
// Verify that next plugin can find metadata in context from all Providers
|
||||||
|
for _, expected := range expectedMetadata {
|
||||||
|
md, ok := FromContext(next.ctx)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("Metadata is expected but not present inside the context")
|
||||||
|
}
|
||||||
|
for expKey, expVal := range expected {
|
||||||
|
metadataVal, valOk := md.Value(expKey)
|
||||||
|
if !valOk {
|
||||||
|
t.Fatalf("Value by key %v can't be retrieved", expKey)
|
||||||
|
}
|
||||||
|
if metadataVal != expVal {
|
||||||
|
t.Errorf("Expected value %v, but got %v", expVal, metadataVal)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
wrongKey := "wrong_key"
|
||||||
|
metadataVal, ok := md.Value(wrongKey)
|
||||||
|
if ok {
|
||||||
|
t.Fatalf("Value by key %v is not expected to be recieved, but got: %v", wrongKey, metadataVal)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
53
plugin/metadata/metadataer.go
Normal file
53
plugin/metadata/metadataer.go
Normal file
|
@ -0,0 +1,53 @@
|
||||||
|
package metadata
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Provider interface needs to be implemented by each plugin willing to provide
|
||||||
|
// metadata information for other plugins.
|
||||||
|
// Note: this method should work quickly, because it is called for every request
|
||||||
|
// from the metadata plugin.
|
||||||
|
type Provider interface {
|
||||||
|
// List of variables which are provided by current Provider. Must remain constant.
|
||||||
|
MetadataVarNames() []string
|
||||||
|
// Metadata is expected to return a value with metadata information by the key
|
||||||
|
// from 4th argument. Value can be later retrieved from context by any other plugin.
|
||||||
|
// If value is not available by some reason returned boolean value should be false.
|
||||||
|
Metadata(context.Context, dns.ResponseWriter, *dns.Msg, string) (interface{}, bool)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MD is metadata information storage
|
||||||
|
type MD map[string]interface{}
|
||||||
|
|
||||||
|
// metadataKey defines the type of key that is used to save metadata into the context
|
||||||
|
type metadataKey struct{}
|
||||||
|
|
||||||
|
// newMD initializes MD and attaches it to context
|
||||||
|
func newMD(ctx context.Context) (MD, context.Context) {
|
||||||
|
m := MD{}
|
||||||
|
return m, context.WithValue(ctx, metadataKey{}, m)
|
||||||
|
}
|
||||||
|
|
||||||
|
// FromContext retrieves MD struct from context.
|
||||||
|
func FromContext(ctx context.Context) (md MD, ok bool) {
|
||||||
|
if metadata := ctx.Value(metadataKey{}); metadata != nil {
|
||||||
|
if md, ok := metadata.(MD); ok {
|
||||||
|
return md, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return MD{}, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value returns metadata value by key.
|
||||||
|
func (m MD) Value(key string) (value interface{}, ok bool) {
|
||||||
|
value, ok = m[key]
|
||||||
|
return value, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// setValue adds metadata value.
|
||||||
|
func (m MD) setValue(key string, val interface{}) {
|
||||||
|
m[key] = val
|
||||||
|
}
|
47
plugin/metadata/metadataer_test.go
Normal file
47
plugin/metadata/metadataer_test.go
Normal file
|
@ -0,0 +1,47 @@
|
||||||
|
package metadata
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMD(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
addValues map[string]interface{}
|
||||||
|
expectedValues map[string]interface{}
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
// Add initial metadata key/vals
|
||||||
|
map[string]interface{}{"key1": "val1", "key2": 2},
|
||||||
|
map[string]interface{}{"key1": "val1", "key2": 2},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// Add additional key/vals.
|
||||||
|
map[string]interface{}{"key3": 3, "key4": 4.5},
|
||||||
|
map[string]interface{}{"key1": "val1", "key2": 2, "key3": 3, "key4": 4.5},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Using one same md and ctx for all test cases
|
||||||
|
ctx := context.TODO()
|
||||||
|
md, ctx := newMD(ctx)
|
||||||
|
|
||||||
|
for i, tc := range tests {
|
||||||
|
for k, v := range tc.addValues {
|
||||||
|
md.setValue(k, v)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(tc.expectedValues, map[string]interface{}(md)) {
|
||||||
|
t.Errorf("Test %d: Expected %v but got %v", i, tc.expectedValues, md)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make sure that MD is recieved from context successfullly
|
||||||
|
mdFromContext, ok := FromContext(ctx)
|
||||||
|
if !ok {
|
||||||
|
t.Errorf("Test %d: MD is not recieved from the context", i)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(md, mdFromContext) {
|
||||||
|
t.Errorf("Test %d: MD recieved from context differs from initial. Initial: %v, from context: %v", i, md, mdFromContext)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
71
plugin/metadata/setup.go
Normal file
71
plugin/metadata/setup.go
Normal file
|
@ -0,0 +1,71 @@
|
||||||
|
package metadata
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/coredns/coredns/core/dnsserver"
|
||||||
|
"github.com/coredns/coredns/plugin"
|
||||||
|
|
||||||
|
"github.com/mholt/caddy"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
caddy.RegisterPlugin("metadata", caddy.Plugin{
|
||||||
|
ServerType: "dns",
|
||||||
|
Action: setup,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func setup(c *caddy.Controller) error {
|
||||||
|
m, err := metadataParse(c)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler {
|
||||||
|
m.Next = next
|
||||||
|
return m
|
||||||
|
})
|
||||||
|
|
||||||
|
c.OnStartup(func() error {
|
||||||
|
plugins := dnsserver.GetConfig(c).Handlers()
|
||||||
|
// Collect all plugins which implement Provider interface
|
||||||
|
metadataVariables := map[string]bool{}
|
||||||
|
for _, p := range plugins {
|
||||||
|
if met, ok := p.(Provider); ok {
|
||||||
|
for _, varName := range met.MetadataVarNames() {
|
||||||
|
if _, ok := metadataVariables[varName]; ok {
|
||||||
|
return fmt.Errorf("Metadata variable '%v' has duplicates", varName)
|
||||||
|
}
|
||||||
|
metadataVariables[varName] = true
|
||||||
|
}
|
||||||
|
m.Providers = append(m.Providers, met)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func metadataParse(c *caddy.Controller) (*Metadata, error) {
|
||||||
|
m := &Metadata{}
|
||||||
|
c.Next()
|
||||||
|
zones := c.RemainingArgs()
|
||||||
|
|
||||||
|
if len(zones) != 0 {
|
||||||
|
m.Zones = zones
|
||||||
|
for i := 0; i < len(m.Zones); i++ {
|
||||||
|
m.Zones[i] = plugin.Host(m.Zones[i]).Normalize()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
m.Zones = make([]string, len(c.ServerBlockKeys))
|
||||||
|
for i := 0; i < len(c.ServerBlockKeys); i++ {
|
||||||
|
m.Zones[i] = plugin.Host(c.ServerBlockKeys[i]).Normalize()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.NextBlock() || c.Next() {
|
||||||
|
return nil, plugin.Error("metadata", c.ArgErr())
|
||||||
|
}
|
||||||
|
return m, nil
|
||||||
|
}
|
70
plugin/metadata/setup_test.go
Normal file
70
plugin/metadata/setup_test.go
Normal file
|
@ -0,0 +1,70 @@
|
||||||
|
package metadata
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/mholt/caddy"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSetup(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
zones []string
|
||||||
|
shouldErr bool
|
||||||
|
}{
|
||||||
|
{"metadata", []string{}, false},
|
||||||
|
{"metadata example.com.", []string{"example.com."}, false},
|
||||||
|
{"metadata example.com. net.", []string{"example.com.", "net."}, false},
|
||||||
|
|
||||||
|
{"metadata example.com. { some_param }", []string{}, true},
|
||||||
|
{"metadata\nmetadata", []string{}, true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, test := range tests {
|
||||||
|
c := caddy.NewTestController("dns", test.input)
|
||||||
|
err := setup(c)
|
||||||
|
|
||||||
|
if test.shouldErr && err == nil {
|
||||||
|
t.Errorf("Test %d: Setup call expected error but found none for input %s", i, test.input)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !test.shouldErr && err != nil {
|
||||||
|
t.Errorf("Test %d: Setup call expected no error but found one for input %s. Error was: %v", i, test.input, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetupHealth(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
zones []string
|
||||||
|
shouldErr bool
|
||||||
|
}{
|
||||||
|
{"metadata", []string{}, false},
|
||||||
|
{"metadata example.com.", []string{"example.com."}, false},
|
||||||
|
{"metadata example.com. net.", []string{"example.com.", "net."}, false},
|
||||||
|
|
||||||
|
{"metadata example.com. { some_param }", []string{}, true},
|
||||||
|
{"metadata\nmetadata", []string{}, true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, test := range tests {
|
||||||
|
c := caddy.NewTestController("dns", test.input)
|
||||||
|
m, err := metadataParse(c)
|
||||||
|
|
||||||
|
if test.shouldErr && err == nil {
|
||||||
|
t.Errorf("Test %d: Expected error but found none for input %s", i, test.input)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !test.shouldErr && err != nil {
|
||||||
|
t.Errorf("Test %d: Expected no error but found one for input %s. Error was: %v", i, test.input, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !test.shouldErr && err == nil {
|
||||||
|
if !reflect.DeepEqual(test.zones, m.Zones) {
|
||||||
|
t.Errorf("Test %d: Expected zones %s. Zones were: %v", i, test.zones, m.Zones)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
109
plugin/pkg/variables/variables.go
Normal file
109
plugin/pkg/variables/variables.go
Normal file
|
@ -0,0 +1,109 @@
|
||||||
|
package variables
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/coredns/coredns/request"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
queryName = "qname"
|
||||||
|
queryType = "qtype"
|
||||||
|
clientIP = "client_ip"
|
||||||
|
clientPort = "client_port"
|
||||||
|
protocol = "protocol"
|
||||||
|
serverIP = "server_ip"
|
||||||
|
serverPort = "server_port"
|
||||||
|
)
|
||||||
|
|
||||||
|
// All is a list of available variables provided by GetMetadataValue
|
||||||
|
var All = []string{queryName, queryType, clientIP, clientPort, protocol, serverIP, serverPort}
|
||||||
|
|
||||||
|
// GetValue calculates and returns the data specified by the variable name.
|
||||||
|
// Supported varNames are listed in allProvidedVars.
|
||||||
|
func GetValue(varName string, w dns.ResponseWriter, r *dns.Msg) ([]byte, error) {
|
||||||
|
req := request.Request{W: w, Req: r}
|
||||||
|
switch varName {
|
||||||
|
case queryName:
|
||||||
|
//Query name is written as ascii string
|
||||||
|
return []byte(req.QName()), nil
|
||||||
|
|
||||||
|
case queryType:
|
||||||
|
return uint16ToWire(req.QType()), nil
|
||||||
|
|
||||||
|
case clientIP:
|
||||||
|
return ipToWire(req.Family(), req.IP())
|
||||||
|
|
||||||
|
case clientPort:
|
||||||
|
return portToWire(req.Port())
|
||||||
|
|
||||||
|
case protocol:
|
||||||
|
// Proto is written as ascii string
|
||||||
|
return []byte(req.Proto()), nil
|
||||||
|
|
||||||
|
case serverIP:
|
||||||
|
ip, _, err := net.SplitHostPort(w.LocalAddr().String())
|
||||||
|
if err != nil {
|
||||||
|
ip = w.RemoteAddr().String()
|
||||||
|
}
|
||||||
|
return ipToWire(family(w.RemoteAddr()), ip)
|
||||||
|
|
||||||
|
case serverPort:
|
||||||
|
_, port, err := net.SplitHostPort(w.LocalAddr().String())
|
||||||
|
if err != nil {
|
||||||
|
port = "0"
|
||||||
|
}
|
||||||
|
return portToWire(port)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("unable to extract data for variable %s", varName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// uint16ToWire writes unit16 to wire/binary format
|
||||||
|
func uint16ToWire(data uint16) []byte {
|
||||||
|
buf := make([]byte, 2)
|
||||||
|
binary.BigEndian.PutUint16(buf, uint16(data))
|
||||||
|
return buf
|
||||||
|
}
|
||||||
|
|
||||||
|
// ipToWire writes IP address to wire/binary format, 4 or 16 bytes depends on IPV4 or IPV6.
|
||||||
|
func ipToWire(family int, ipAddr string) ([]byte, error) {
|
||||||
|
|
||||||
|
switch family {
|
||||||
|
case 1:
|
||||||
|
return net.ParseIP(ipAddr).To4(), nil
|
||||||
|
case 2:
|
||||||
|
return net.ParseIP(ipAddr).To16(), nil
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("invalid IP address family (i.e. version) %d", family)
|
||||||
|
}
|
||||||
|
|
||||||
|
// portToWire writes port to wire/binary format, 2 bytes
|
||||||
|
func portToWire(portStr string) ([]byte, error) {
|
||||||
|
|
||||||
|
port, err := strconv.ParseUint(portStr, 10, 16)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return uint16ToWire(uint16(port)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Family returns the family of the transport, 1 for IPv4 and 2 for IPv6.
|
||||||
|
func family(ip net.Addr) int {
|
||||||
|
var a net.IP
|
||||||
|
if i, ok := ip.(*net.UDPAddr); ok {
|
||||||
|
a = i.IP
|
||||||
|
}
|
||||||
|
if i, ok := ip.(*net.TCPAddr); ok {
|
||||||
|
a = i.IP
|
||||||
|
}
|
||||||
|
if a.To4() != nil {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return 2
|
||||||
|
}
|
80
plugin/pkg/variables/variables_test.go
Normal file
80
plugin/pkg/variables/variables_test.go
Normal file
|
@ -0,0 +1,80 @@
|
||||||
|
package variables
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/coredns/coredns/plugin/test"
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetValue(t *testing.T) {
|
||||||
|
// test.ResponseWriter has the following values:
|
||||||
|
// The remote will always be 10.240.0.1 and port 40212.
|
||||||
|
// The local address is always 127.0.0.1 and port 53.
|
||||||
|
tests := []struct {
|
||||||
|
varName string
|
||||||
|
expectedValue []byte
|
||||||
|
shouldErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
queryName,
|
||||||
|
[]byte("example.com."),
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
queryType,
|
||||||
|
[]byte{0x00, 0x01},
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
clientIP,
|
||||||
|
[]byte{10, 240, 0, 1},
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
clientPort,
|
||||||
|
[]byte{0x9D, 0x14},
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
protocol,
|
||||||
|
[]byte("udp"),
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
serverIP,
|
||||||
|
[]byte{127, 0, 0, 1},
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
serverPort,
|
||||||
|
[]byte{0, 53},
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"wrong_var",
|
||||||
|
[]byte{},
|
||||||
|
true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, tc := range tests {
|
||||||
|
m := new(dns.Msg)
|
||||||
|
m.SetQuestion("example.com.", dns.TypeA)
|
||||||
|
m.Question[0].Qclass = dns.ClassINET
|
||||||
|
|
||||||
|
value, err := GetValue(tc.varName, &test.ResponseWriter{}, m)
|
||||||
|
|
||||||
|
if tc.shouldErr && err == nil {
|
||||||
|
t.Errorf("Test %d: Expected error, but didn't recieve", i)
|
||||||
|
}
|
||||||
|
if !tc.shouldErr && err != nil {
|
||||||
|
t.Errorf("Test %d: Expected no error, but got error: %v", i, err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.Equal(tc.expectedValue, value) {
|
||||||
|
t.Errorf("Test %d: Expected %v but got %v", i, tc.expectedValue, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -206,13 +206,17 @@ rewrites the first local option with code 0xffee, setting the data to "abcd". Eq
|
||||||
}
|
}
|
||||||
~~~
|
~~~
|
||||||
|
|
||||||
* A variable data is specified with a pair of curly brackets `{}`. Following are the supported variables:
|
* A variable data is specified with a pair of curly brackets `{}`. Following are the supported variables by default:
|
||||||
{qname}, {qtype}, {client_ip}, {client_port}, {protocol}, {server_ip}, {server_port}.
|
{qname}, {qtype}, {client_ip}, {client_port}, {protocol}, {server_ip}, {server_port}.
|
||||||
|
Any plugin that can provide it's own additional variables by implementing metadata.Provider interface. If you are going to use metadata variables then metadata plugin must be enabled.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
~~~
|
~~~ corefile
|
||||||
rewrite edns0 local set 0xffee {client_ip}
|
. {
|
||||||
|
metadata
|
||||||
|
rewrite edns0 local set 0xffee {client_ip}
|
||||||
|
}
|
||||||
~~~
|
~~~
|
||||||
|
|
||||||
### EDNS0_NSID
|
### EDNS0_NSID
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package rewrite
|
package rewrite
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
@ -27,7 +28,7 @@ func newClassRule(nextAction string, args ...string) (Rule, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Rewrite rewrites the the current request.
|
// Rewrite rewrites the the current request.
|
||||||
func (rule *classRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
|
func (rule *classRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result {
|
||||||
if rule.fromClass > 0 && rule.toClass > 0 {
|
if rule.fromClass > 0 && rule.toClass > 0 {
|
||||||
if r.Question[0].Qclass == rule.fromClass {
|
if r.Question[0].Qclass == rule.fromClass {
|
||||||
r.Question[0].Qclass = rule.toClass
|
r.Question[0].Qclass = rule.toClass
|
||||||
|
|
|
@ -2,13 +2,15 @@
|
||||||
package rewrite
|
package rewrite
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
"context"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/coredns/coredns/plugin/metadata"
|
||||||
|
"github.com/coredns/coredns/plugin/pkg/variables"
|
||||||
"github.com/coredns/coredns/request"
|
"github.com/coredns/coredns/request"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
)
|
)
|
||||||
|
@ -46,7 +48,7 @@ func setupEdns0Opt(r *dns.Msg) *dns.OPT {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Rewrite will alter the request EDNS0 NSID option
|
// Rewrite will alter the request EDNS0 NSID option
|
||||||
func (rule *edns0NsidRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
|
func (rule *edns0NsidRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result {
|
||||||
result := RewriteIgnored
|
result := RewriteIgnored
|
||||||
o := setupEdns0Opt(r)
|
o := setupEdns0Opt(r)
|
||||||
found := false
|
found := false
|
||||||
|
@ -83,7 +85,7 @@ func (rule *edns0NsidRule) GetResponseRule() ResponseRule {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Rewrite will alter the request EDNS0 local options
|
// Rewrite will alter the request EDNS0 local options
|
||||||
func (rule *edns0LocalRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
|
func (rule *edns0LocalRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result {
|
||||||
result := RewriteIgnored
|
result := RewriteIgnored
|
||||||
o := setupEdns0Opt(r)
|
o := setupEdns0Opt(r)
|
||||||
found := false
|
found := false
|
||||||
|
@ -146,7 +148,9 @@ func newEdns0Rule(mode string, args ...string) (Rule, error) {
|
||||||
}
|
}
|
||||||
//Check for variable option
|
//Check for variable option
|
||||||
if strings.HasPrefix(args[3], "{") && strings.HasSuffix(args[3], "}") {
|
if strings.HasPrefix(args[3], "{") && strings.HasSuffix(args[3], "}") {
|
||||||
return newEdns0VariableRule(mode, action, args[2], args[3])
|
// Remove first and last runes
|
||||||
|
variable := args[3][1 : len(args[3])-1]
|
||||||
|
return newEdns0VariableRule(mode, action, args[2], variable)
|
||||||
}
|
}
|
||||||
return newEdns0LocalRule(mode, action, args[2], args[3])
|
return newEdns0LocalRule(mode, action, args[2], args[3])
|
||||||
case "nsid":
|
case "nsid":
|
||||||
|
@ -186,102 +190,28 @@ func newEdns0VariableRule(mode, action, code, variable string) (*edns0VariableRu
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
//Validate
|
|
||||||
if !isValidVariable(variable) {
|
|
||||||
return nil, fmt.Errorf("unsupported variable name %q", variable)
|
|
||||||
}
|
|
||||||
return &edns0VariableRule{mode: mode, action: action, code: uint16(c), variable: variable}, nil
|
return &edns0VariableRule{mode: mode, action: action, code: uint16(c), variable: variable}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ipToWire writes IP address to wire/binary format, 4 or 16 bytes depends on IPV4 or IPV6.
|
|
||||||
func (rule *edns0VariableRule) ipToWire(family int, ipAddr string) ([]byte, error) {
|
|
||||||
|
|
||||||
switch family {
|
|
||||||
case 1:
|
|
||||||
return net.ParseIP(ipAddr).To4(), nil
|
|
||||||
case 2:
|
|
||||||
return net.ParseIP(ipAddr).To16(), nil
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("invalid IP address family (i.e. version) %d", family)
|
|
||||||
}
|
|
||||||
|
|
||||||
// uint16ToWire writes unit16 to wire/binary format
|
|
||||||
func (rule *edns0VariableRule) uint16ToWire(data uint16) []byte {
|
|
||||||
buf := make([]byte, 2)
|
|
||||||
binary.BigEndian.PutUint16(buf, uint16(data))
|
|
||||||
return buf
|
|
||||||
}
|
|
||||||
|
|
||||||
// portToWire writes port to wire/binary format, 2 bytes
|
|
||||||
func (rule *edns0VariableRule) portToWire(portStr string) ([]byte, error) {
|
|
||||||
|
|
||||||
port, err := strconv.ParseUint(portStr, 10, 16)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return rule.uint16ToWire(uint16(port)), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Family returns the family of the transport, 1 for IPv4 and 2 for IPv6.
|
|
||||||
func (rule *edns0VariableRule) family(ip net.Addr) int {
|
|
||||||
var a net.IP
|
|
||||||
if i, ok := ip.(*net.UDPAddr); ok {
|
|
||||||
a = i.IP
|
|
||||||
}
|
|
||||||
if i, ok := ip.(*net.TCPAddr); ok {
|
|
||||||
a = i.IP
|
|
||||||
}
|
|
||||||
if a.To4() != nil {
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
return 2
|
|
||||||
}
|
|
||||||
|
|
||||||
// ruleData returns the data specified by the variable
|
// ruleData returns the data specified by the variable
|
||||||
func (rule *edns0VariableRule) ruleData(w dns.ResponseWriter, r *dns.Msg) ([]byte, error) {
|
func (rule *edns0VariableRule) ruleData(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) ([]byte, error) {
|
||||||
|
if md, ok := metadata.FromContext(ctx); ok {
|
||||||
req := request.Request{W: w, Req: r}
|
if value, ok := md.Value(rule.variable); ok {
|
||||||
switch rule.variable {
|
if v, ok := value.([]byte); ok {
|
||||||
case queryName:
|
return v, nil
|
||||||
//Query name is written as ascii string
|
}
|
||||||
return []byte(req.QName()), nil
|
|
||||||
|
|
||||||
case queryType:
|
|
||||||
return rule.uint16ToWire(req.QType()), nil
|
|
||||||
|
|
||||||
case clientIP:
|
|
||||||
return rule.ipToWire(req.Family(), req.IP())
|
|
||||||
|
|
||||||
case clientPort:
|
|
||||||
return rule.portToWire(req.Port())
|
|
||||||
|
|
||||||
case protocol:
|
|
||||||
// Proto is written as ascii string
|
|
||||||
return []byte(req.Proto()), nil
|
|
||||||
|
|
||||||
case serverIP:
|
|
||||||
ip, _, err := net.SplitHostPort(w.LocalAddr().String())
|
|
||||||
if err != nil {
|
|
||||||
ip = w.RemoteAddr().String()
|
|
||||||
}
|
}
|
||||||
return rule.ipToWire(rule.family(w.RemoteAddr()), ip)
|
} else { // No metadata available means metadata plugin is disabled. Try to get the value directly.
|
||||||
|
return variables.GetValue(rule.variable, w, r)
|
||||||
case serverPort:
|
|
||||||
_, port, err := net.SplitHostPort(w.LocalAddr().String())
|
|
||||||
if err != nil {
|
|
||||||
port = "0"
|
|
||||||
}
|
|
||||||
return rule.portToWire(port)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, fmt.Errorf("unable to extract data for variable %s", rule.variable)
|
return nil, fmt.Errorf("unable to extract data for variable %s", rule.variable)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Rewrite will alter the request EDNS0 local options with specified variables
|
// Rewrite will alter the request EDNS0 local options with specified variables
|
||||||
func (rule *edns0VariableRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
|
func (rule *edns0VariableRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result {
|
||||||
result := RewriteIgnored
|
result := RewriteIgnored
|
||||||
|
|
||||||
data, err := rule.ruleData(w, r)
|
data, err := rule.ruleData(ctx, w, r)
|
||||||
if err != nil || data == nil {
|
if err != nil || data == nil {
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
@ -324,21 +254,6 @@ func (rule *edns0VariableRule) GetResponseRule() ResponseRule {
|
||||||
return ResponseRule{}
|
return ResponseRule{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func isValidVariable(variable string) bool {
|
|
||||||
switch variable {
|
|
||||||
case
|
|
||||||
queryName,
|
|
||||||
queryType,
|
|
||||||
clientIP,
|
|
||||||
clientPort,
|
|
||||||
protocol,
|
|
||||||
serverIP,
|
|
||||||
serverPort:
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// ends0SubnetRule is a rewrite rule for EDNS0 subnet options
|
// ends0SubnetRule is a rewrite rule for EDNS0 subnet options
|
||||||
type edns0SubnetRule struct {
|
type edns0SubnetRule struct {
|
||||||
mode string
|
mode string
|
||||||
|
@ -400,7 +315,7 @@ func (rule *edns0SubnetRule) fillEcsData(w dns.ResponseWriter, r *dns.Msg,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Rewrite will alter the request EDNS0 subnet option
|
// Rewrite will alter the request EDNS0 subnet option
|
||||||
func (rule *edns0SubnetRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
|
func (rule *edns0SubnetRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result {
|
||||||
result := RewriteIgnored
|
result := RewriteIgnored
|
||||||
o := setupEdns0Opt(r)
|
o := setupEdns0Opt(r)
|
||||||
found := false
|
found := false
|
||||||
|
@ -446,17 +361,6 @@ const (
|
||||||
Append = "append"
|
Append = "append"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Supported local EDNS0 variables
|
|
||||||
const (
|
|
||||||
queryName = "{qname}"
|
|
||||||
queryType = "{qtype}"
|
|
||||||
clientIP = "{client_ip}"
|
|
||||||
clientPort = "{client_port}"
|
|
||||||
protocol = "{protocol}"
|
|
||||||
serverIP = "{server_ip}"
|
|
||||||
serverPort = "{server_port}"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Subnet maximum bit mask length
|
// Subnet maximum bit mask length
|
||||||
const (
|
const (
|
||||||
maxV4BitMaskLen = 32
|
maxV4BitMaskLen = 32
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package rewrite
|
package rewrite
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
@ -56,7 +57,7 @@ const (
|
||||||
|
|
||||||
// Rewrite rewrites the current request based upon exact match of the name
|
// Rewrite rewrites the current request based upon exact match of the name
|
||||||
// in the question section of the request
|
// in the question section of the request
|
||||||
func (rule *nameRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
|
func (rule *nameRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result {
|
||||||
if rule.From == r.Question[0].Name {
|
if rule.From == r.Question[0].Name {
|
||||||
r.Question[0].Name = rule.To
|
r.Question[0].Name = rule.To
|
||||||
return RewriteDone
|
return RewriteDone
|
||||||
|
@ -65,7 +66,7 @@ func (rule *nameRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Rewrite rewrites the current request when the name begins with the matching string
|
// Rewrite rewrites the current request when the name begins with the matching string
|
||||||
func (rule *prefixNameRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
|
func (rule *prefixNameRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result {
|
||||||
if strings.HasPrefix(r.Question[0].Name, rule.Prefix) {
|
if strings.HasPrefix(r.Question[0].Name, rule.Prefix) {
|
||||||
r.Question[0].Name = rule.Replacement + strings.TrimLeft(r.Question[0].Name, rule.Prefix)
|
r.Question[0].Name = rule.Replacement + strings.TrimLeft(r.Question[0].Name, rule.Prefix)
|
||||||
return RewriteDone
|
return RewriteDone
|
||||||
|
@ -74,7 +75,7 @@ func (rule *prefixNameRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Rewrite rewrites the current request when the name ends with the matching string
|
// Rewrite rewrites the current request when the name ends with the matching string
|
||||||
func (rule *suffixNameRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
|
func (rule *suffixNameRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result {
|
||||||
if strings.HasSuffix(r.Question[0].Name, rule.Suffix) {
|
if strings.HasSuffix(r.Question[0].Name, rule.Suffix) {
|
||||||
r.Question[0].Name = strings.TrimRight(r.Question[0].Name, rule.Suffix) + rule.Replacement
|
r.Question[0].Name = strings.TrimRight(r.Question[0].Name, rule.Suffix) + rule.Replacement
|
||||||
return RewriteDone
|
return RewriteDone
|
||||||
|
@ -84,7 +85,7 @@ func (rule *suffixNameRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
|
||||||
|
|
||||||
// Rewrite rewrites the current request based upon partial match of the
|
// Rewrite rewrites the current request based upon partial match of the
|
||||||
// name in the question section of the request
|
// name in the question section of the request
|
||||||
func (rule *substringNameRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
|
func (rule *substringNameRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result {
|
||||||
if strings.Contains(r.Question[0].Name, rule.Substring) {
|
if strings.Contains(r.Question[0].Name, rule.Substring) {
|
||||||
r.Question[0].Name = strings.Replace(r.Question[0].Name, rule.Substring, rule.Replacement, -1)
|
r.Question[0].Name = strings.Replace(r.Question[0].Name, rule.Substring, rule.Replacement, -1)
|
||||||
return RewriteDone
|
return RewriteDone
|
||||||
|
@ -94,7 +95,7 @@ func (rule *substringNameRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result
|
||||||
|
|
||||||
// Rewrite rewrites the current request when the name in the question
|
// Rewrite rewrites the current request when the name in the question
|
||||||
// section of the request matches a regular expression
|
// section of the request matches a regular expression
|
||||||
func (rule *regexNameRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
|
func (rule *regexNameRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result {
|
||||||
regexGroups := rule.Pattern.FindStringSubmatch(r.Question[0].Name)
|
regexGroups := rule.Pattern.FindStringSubmatch(r.Question[0].Name)
|
||||||
if len(regexGroups) == 0 {
|
if len(regexGroups) == 0 {
|
||||||
return RewriteIgnored
|
return RewriteIgnored
|
||||||
|
|
|
@ -42,7 +42,7 @@ type Rewrite struct {
|
||||||
func (rw Rewrite) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
func (rw Rewrite) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||||
wr := NewResponseReverter(w, r)
|
wr := NewResponseReverter(w, r)
|
||||||
for _, rule := range rw.Rules {
|
for _, rule := range rw.Rules {
|
||||||
switch result := rule.Rewrite(w, r); result {
|
switch result := rule.Rewrite(ctx, w, r); result {
|
||||||
case RewriteDone:
|
case RewriteDone:
|
||||||
respRule := rule.GetResponseRule()
|
respRule := rule.GetResponseRule()
|
||||||
if respRule.Active == true {
|
if respRule.Active == true {
|
||||||
|
@ -76,7 +76,7 @@ func (rw Rewrite) Name() string { return "rewrite" }
|
||||||
// Rule describes a rewrite rule.
|
// Rule describes a rewrite rule.
|
||||||
type Rule interface {
|
type Rule interface {
|
||||||
// Rewrite rewrites the current request.
|
// Rewrite rewrites the current request.
|
||||||
Rewrite(dns.ResponseWriter, *dns.Msg) Result
|
Rewrite(context.Context, dns.ResponseWriter, *dns.Msg) Result
|
||||||
// Mode returns the processing mode stop or continue.
|
// Mode returns the processing mode stop or continue.
|
||||||
Mode() string
|
Mode() string
|
||||||
// GetResponseRule returns the rule to rewrite response with, if any.
|
// GetResponseRule returns the rule to rewrite response with, if any.
|
||||||
|
|
|
@ -71,7 +71,7 @@ func TestNewRule(t *testing.T) {
|
||||||
{[]string{"edns0", "nsid", "append"}, false, reflect.TypeOf(&edns0NsidRule{})},
|
{[]string{"edns0", "nsid", "append"}, false, reflect.TypeOf(&edns0NsidRule{})},
|
||||||
{[]string{"edns0", "nsid", "replace"}, false, reflect.TypeOf(&edns0NsidRule{})},
|
{[]string{"edns0", "nsid", "replace"}, false, reflect.TypeOf(&edns0NsidRule{})},
|
||||||
{[]string{"edns0", "nsid", "foo"}, true, nil},
|
{[]string{"edns0", "nsid", "foo"}, true, nil},
|
||||||
{[]string{"edns0", "local", "set", "0xffee", "{dummy}"}, true, nil},
|
{[]string{"edns0", "local", "set", "0xffee", "{dummy}"}, false, reflect.TypeOf(&edns0VariableRule{})},
|
||||||
{[]string{"edns0", "local", "set", "0xffee", "{qname}"}, false, reflect.TypeOf(&edns0VariableRule{})},
|
{[]string{"edns0", "local", "set", "0xffee", "{qname}"}, false, reflect.TypeOf(&edns0VariableRule{})},
|
||||||
{[]string{"edns0", "local", "set", "0xffee", "{qtype}"}, false, reflect.TypeOf(&edns0VariableRule{})},
|
{[]string{"edns0", "local", "set", "0xffee", "{qtype}"}, false, reflect.TypeOf(&edns0VariableRule{})},
|
||||||
{[]string{"edns0", "local", "set", "0xffee", "{client_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})},
|
{[]string{"edns0", "local", "set", "0xffee", "{client_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})},
|
||||||
|
@ -79,7 +79,7 @@ func TestNewRule(t *testing.T) {
|
||||||
{[]string{"edns0", "local", "set", "0xffee", "{protocol}"}, false, reflect.TypeOf(&edns0VariableRule{})},
|
{[]string{"edns0", "local", "set", "0xffee", "{protocol}"}, false, reflect.TypeOf(&edns0VariableRule{})},
|
||||||
{[]string{"edns0", "local", "set", "0xffee", "{server_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})},
|
{[]string{"edns0", "local", "set", "0xffee", "{server_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})},
|
||||||
{[]string{"edns0", "local", "set", "0xffee", "{server_port}"}, false, reflect.TypeOf(&edns0VariableRule{})},
|
{[]string{"edns0", "local", "set", "0xffee", "{server_port}"}, false, reflect.TypeOf(&edns0VariableRule{})},
|
||||||
{[]string{"edns0", "local", "append", "0xffee", "{dummy}"}, true, nil},
|
{[]string{"edns0", "local", "append", "0xffee", "{dummy}"}, false, reflect.TypeOf(&edns0VariableRule{})},
|
||||||
{[]string{"edns0", "local", "append", "0xffee", "{qname}"}, false, reflect.TypeOf(&edns0VariableRule{})},
|
{[]string{"edns0", "local", "append", "0xffee", "{qname}"}, false, reflect.TypeOf(&edns0VariableRule{})},
|
||||||
{[]string{"edns0", "local", "append", "0xffee", "{qtype}"}, false, reflect.TypeOf(&edns0VariableRule{})},
|
{[]string{"edns0", "local", "append", "0xffee", "{qtype}"}, false, reflect.TypeOf(&edns0VariableRule{})},
|
||||||
{[]string{"edns0", "local", "append", "0xffee", "{client_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})},
|
{[]string{"edns0", "local", "append", "0xffee", "{client_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})},
|
||||||
|
@ -87,7 +87,7 @@ func TestNewRule(t *testing.T) {
|
||||||
{[]string{"edns0", "local", "append", "0xffee", "{protocol}"}, false, reflect.TypeOf(&edns0VariableRule{})},
|
{[]string{"edns0", "local", "append", "0xffee", "{protocol}"}, false, reflect.TypeOf(&edns0VariableRule{})},
|
||||||
{[]string{"edns0", "local", "append", "0xffee", "{server_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})},
|
{[]string{"edns0", "local", "append", "0xffee", "{server_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})},
|
||||||
{[]string{"edns0", "local", "append", "0xffee", "{server_port}"}, false, reflect.TypeOf(&edns0VariableRule{})},
|
{[]string{"edns0", "local", "append", "0xffee", "{server_port}"}, false, reflect.TypeOf(&edns0VariableRule{})},
|
||||||
{[]string{"edns0", "local", "replace", "0xffee", "{dummy}"}, true, nil},
|
{[]string{"edns0", "local", "replace", "0xffee", "{dummy}"}, false, reflect.TypeOf(&edns0VariableRule{})},
|
||||||
{[]string{"edns0", "local", "replace", "0xffee", "{qname}"}, false, reflect.TypeOf(&edns0VariableRule{})},
|
{[]string{"edns0", "local", "replace", "0xffee", "{qname}"}, false, reflect.TypeOf(&edns0VariableRule{})},
|
||||||
{[]string{"edns0", "local", "replace", "0xffee", "{qtype}"}, false, reflect.TypeOf(&edns0VariableRule{})},
|
{[]string{"edns0", "local", "replace", "0xffee", "{qtype}"}, false, reflect.TypeOf(&edns0VariableRule{})},
|
||||||
{[]string{"edns0", "local", "replace", "0xffee", "{client_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})},
|
{[]string{"edns0", "local", "replace", "0xffee", "{client_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})},
|
||||||
|
|
|
@ -2,6 +2,7 @@
|
||||||
package rewrite
|
package rewrite
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
@ -28,7 +29,7 @@ func newTypeRule(nextAction string, args ...string) (Rule, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Rewrite rewrites the the current request.
|
// Rewrite rewrites the the current request.
|
||||||
func (rule *typeRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result {
|
func (rule *typeRule) Rewrite(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) Result {
|
||||||
if rule.fromType > 0 && rule.toType > 0 {
|
if rule.fromType > 0 && rule.toType > 0 {
|
||||||
if r.Question[0].Qtype == rule.fromType {
|
if r.Question[0].Qtype == rule.fromType {
|
||||||
r.Question[0].Qtype = rule.toType
|
r.Question[0].Qtype = rule.toType
|
||||||
|
|
Loading…
Add table
Reference in a new issue