diff --git a/middleware/rewrite/README.md b/middleware/rewrite/README.md index 0ab1e1a1d..8fa38a377 100644 --- a/middleware/rewrite/README.md +++ b/middleware/rewrite/README.md @@ -41,8 +41,9 @@ Currently supported are `EDNS0_LOCAL` and `EDNS0_NSID`. ### `EDNS0_LOCAL` -This has two fields, code and data. A match is defined as having the same code. Data may be a string, or if -it starts with `0x` it will be treated as hex. Example: +This has two fields, code and data. A match is defined as having the same code. Data may be a string or a variable. + +* A string data can be treated as hex if it starts with `0x`. Example: ~~~ rewrite edns0 local set 0xffee 0x61626364 @@ -54,6 +55,21 @@ rewrites the first local option with code 0xffee, setting the data to "abcd". Eq rewrite edns0 local set 0xffee abcd ~~~ +* A variable data is specified with a pair of curly brackets `{}`. Following are the supported variables: + * {qname} + * {qtype} + * {client_ip} + * {client_port} + * {protocol} + * {server_ip} + * {server_port} + +Example: + +~~~ +rewrite edns0 local set 0xffee {client_ip} +~~~ + ### `EDNS0_NSID` This has no fields; it will add an NSID option with an empty string for the NSID. If the option already exists diff --git a/middleware/rewrite/class.go b/middleware/rewrite/class.go index a2bd00d0f..8cc7d26b7 100644 --- a/middleware/rewrite/class.go +++ b/middleware/rewrite/class.go @@ -24,7 +24,7 @@ func newClassRule(fromS, toS string) (Rule, error) { } // Rewrite rewrites the the current request. -func (rule *classRule) Rewrite(r *dns.Msg) Result { +func (rule *classRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result { if rule.fromClass > 0 && rule.toClass > 0 { if r.Question[0].Qclass == rule.fromClass { r.Question[0].Qclass = rule.toClass diff --git a/middleware/rewrite/edns0.go b/middleware/rewrite/edns0.go index db98bead8..156d7c06c 100644 --- a/middleware/rewrite/edns0.go +++ b/middleware/rewrite/edns0.go @@ -2,11 +2,14 @@ package rewrite import ( + "encoding/binary" "encoding/hex" "fmt" + "net" "strconv" "strings" + "github.com/coredns/coredns/request" "github.com/miekg/dns" ) @@ -17,6 +20,13 @@ type edns0LocalRule struct { data []byte } +// edns0VariableRule is a rewrite rule for EDNS0_LOCAL options with variable +type edns0VariableRule struct { + action string + code uint16 + variable string +} + // ends0NsidRule is a rewrite rule for EDNS0_NSID options type edns0NsidRule struct { action string @@ -33,7 +43,7 @@ func setupEdns0Opt(r *dns.Msg) *dns.OPT { } // Rewrite will alter the request EDNS0 NSID option -func (rule *edns0NsidRule) Rewrite(r *dns.Msg) Result { +func (rule *edns0NsidRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result { result := RewriteIgnored o := setupEdns0Opt(r) found := false @@ -61,7 +71,7 @@ Option: } // Rewrite will alter the request EDNS0 local options -func (rule *edns0LocalRule) Rewrite(r *dns.Msg) Result { +func (rule *edns0LocalRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result { result := RewriteIgnored o := setupEdns0Opt(r) found := false @@ -113,6 +123,10 @@ func newEdns0Rule(args ...string) (Rule, error) { if len(args) != 4 { return nil, fmt.Errorf("EDNS0 local rules require exactly three args") } + //Check for variable option + if strings.HasPrefix(args[3], "{") && strings.HasSuffix(args[3], "}") { + return newEdns0VariableRule(action, args[2], args[3]) + } return newEdns0LocalRule(action, args[2], args[3]) case "nsid": if len(args) != 2 { @@ -137,13 +151,173 @@ func newEdns0LocalRule(action, code, data string) (*edns0LocalRule, error) { return nil, err } } - return &edns0LocalRule{action: action, code: uint16(c), data: decoded}, nil } +// newEdns0VariableRule creates an EDNS0 rule that handles variable substitution +func newEdns0VariableRule(action, code, variable string) (*edns0VariableRule, error) { + c, err := strconv.ParseUint(code, 0, 16) + if err != nil { + return nil, err + } + //Validate + if !isValidVariable(variable) { + return nil, fmt.Errorf("unsupported variable name %q", variable) + } + return &edns0VariableRule{action: action, code: uint16(c), variable: variable}, nil +} + +// ipToWire writes IP address to wire/binary format, 4 or 16 bytes depends on IPV4 or IPV6. +func (rule *edns0VariableRule) ipToWire(family int, ipAddr string) ([]byte, error) { + + switch family { + case 1: + return net.ParseIP(ipAddr).To4(), nil + case 2: + return net.ParseIP(ipAddr).To16(), nil + } + return nil, fmt.Errorf("Invalid IP address family (i.e. version) %d", family) +} + +// uint16ToWire writes unit16 to wire/binary format +func (rule *edns0VariableRule) uint16ToWire(data uint16) []byte { + buf := make([]byte, 2) + binary.BigEndian.PutUint16(buf, uint16(data)) + return buf +} + +// portToWire writes port to wire/binary format, 2 bytes +func (rule *edns0VariableRule) portToWire(portStr string) ([]byte, error) { + + port, err := strconv.ParseUint(portStr, 10, 16) + if err != nil { + return nil, err + } + return rule.uint16ToWire(uint16(port)), nil +} + +// Family returns the family of the transport, 1 for IPv4 and 2 for IPv6. +func (rule *edns0VariableRule) family(ip net.Addr) int { + var a net.IP + if i, ok := ip.(*net.UDPAddr); ok { + a = i.IP + } + if i, ok := ip.(*net.TCPAddr); ok { + a = i.IP + } + if a.To4() != nil { + return 1 + } + return 2 +} + +// ruleData returns the data specified by the variable +func (rule *edns0VariableRule) ruleData(w dns.ResponseWriter, r *dns.Msg) ([]byte, error) { + + req := request.Request{W: w, Req: r} + switch rule.variable { + case queryName: + //Query name is written as ascii string + return []byte(req.QName()), nil + + case queryType: + return rule.uint16ToWire(req.QType()), nil + + case clientIP: + return rule.ipToWire(req.Family(), req.IP()) + + case clientPort: + return rule.portToWire(req.Port()) + + case protocol: + // Proto is written as ascii string + return []byte(req.Proto()), nil + + case serverIP: + serverIp, _, err := net.SplitHostPort(w.LocalAddr().String()) + if err != nil { + serverIp = w.RemoteAddr().String() + } + return rule.ipToWire(rule.family(w.RemoteAddr()), serverIp) + + case serverPort: + _, port, err := net.SplitHostPort(w.LocalAddr().String()) + if err != nil { + port = "0" + } + return rule.portToWire(port) + } + + return nil, fmt.Errorf("Unable to extract data for variable %s", rule.variable) +} + +// Rewrite will alter the request EDNS0 local options with specified variables +func (rule *edns0VariableRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result { + result := RewriteIgnored + + data, err := rule.ruleData(w, r) + if err != nil || data == nil { + return result + } + + o := setupEdns0Opt(r) + found := false + for _, s := range o.Option { + switch e := s.(type) { + case *dns.EDNS0_LOCAL: + if rule.code == e.Code { + if rule.action == Replace || rule.action == Set { + e.Data = data + result = RewriteDone + } + found = true + break + } + } + } + + // add option if not found + if !found && (rule.action == Append || rule.action == Set) { + o.SetDo() + var opt dns.EDNS0_LOCAL + opt.Code = rule.code + opt.Data = data + o.Option = append(o.Option, &opt) + result = RewriteDone + } + + return result +} + +func isValidVariable(variable string) bool { + switch variable { + case + queryName, + queryType, + clientIP, + clientPort, + protocol, + serverIP, + serverPort: + return true + } + return false +} + // These are all defined actions. const ( Replace = "replace" Set = "set" Append = "append" ) + +// Supported local EDNS0 variables +const ( + queryName = "{qname}" + queryType = "{qtype}" + clientIP = "{client_ip}" + clientPort = "{client_port}" + protocol = "{protocol}" + serverIP = "{server_ip}" + serverPort = "{server_port}" +) diff --git a/middleware/rewrite/name.go b/middleware/rewrite/name.go index 69c387f10..6233197d6 100644 --- a/middleware/rewrite/name.go +++ b/middleware/rewrite/name.go @@ -15,7 +15,7 @@ func newNameRule(from, to string) (Rule, error) { } // Rewrite rewrites the the current request. -func (rule *nameRule) Rewrite(r *dns.Msg) Result { +func (rule *nameRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result { if rule.From == r.Question[0].Name { r.Question[0].Name = rule.To return RewriteDone diff --git a/middleware/rewrite/rewrite.go b/middleware/rewrite/rewrite.go index aa27c3a2a..44e8e43c7 100644 --- a/middleware/rewrite/rewrite.go +++ b/middleware/rewrite/rewrite.go @@ -35,7 +35,7 @@ type Rewrite struct { func (rw Rewrite) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { wr := NewResponseReverter(w, r) for _, rule := range rw.Rules { - switch result := rule.Rewrite(r); result { + switch result := rule.Rewrite(w, r); result { case RewriteDone: if rw.noRevert { return middleware.NextOrFailure(rw.Name(), rw.Next, ctx, w, r) @@ -59,7 +59,7 @@ func (rw Rewrite) Name() string { return "rewrite" } // Rule describes a rewrite rule. type Rule interface { // Rewrite rewrites the current request. - Rewrite(*dns.Msg) Result + Rewrite(dns.ResponseWriter, *dns.Msg) Result } func newRule(args ...string) (Rule, error) { diff --git a/middleware/rewrite/rewrite_test.go b/middleware/rewrite/rewrite_test.go index e4c0afc50..f6337ab5e 100644 --- a/middleware/rewrite/rewrite_test.go +++ b/middleware/rewrite/rewrite_test.go @@ -57,6 +57,30 @@ func TestNewRule(t *testing.T) { {[]string{"edns0", "nsid", "append"}, false, reflect.TypeOf(&edns0NsidRule{})}, {[]string{"edns0", "nsid", "replace"}, false, reflect.TypeOf(&edns0NsidRule{})}, {[]string{"edns0", "nsid", "foo"}, true, nil}, + {[]string{"edns0", "local", "set", "0xffee", "{dummy}"}, true, nil}, + {[]string{"edns0", "local", "set", "0xffee", "{qname}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", "set", "0xffee", "{qtype}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", "set", "0xffee", "{client_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", "set", "0xffee", "{client_port}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", "set", "0xffee", "{protocol}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", "set", "0xffee", "{server_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", "set", "0xffee", "{server_port}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", "append", "0xffee", "{dummy}"}, true, nil}, + {[]string{"edns0", "local", "append", "0xffee", "{qname}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", "append", "0xffee", "{qtype}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", "append", "0xffee", "{client_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", "append", "0xffee", "{client_port}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", "append", "0xffee", "{protocol}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", "append", "0xffee", "{server_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", "append", "0xffee", "{server_port}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", "replace", "0xffee", "{dummy}"}, true, nil}, + {[]string{"edns0", "local", "replace", "0xffee", "{qname}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", "replace", "0xffee", "{qtype}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", "replace", "0xffee", "{client_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", "replace", "0xffee", "{client_port}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", "replace", "0xffee", "{protocol}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", "replace", "0xffee", "{server_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", "replace", "0xffee", "{server_port}"}, false, reflect.TypeOf(&edns0VariableRule{})}, } for i, tc := range tests { @@ -285,3 +309,83 @@ func optsEqual(a, b []dns.EDNS0) bool { } return true } + +func TestRewriteEDNS0LocalVariable(t *testing.T) { + rw := Rewrite{ + Next: middleware.HandlerFunc(msgPrinter), + noRevert: true, + } + + // test.ResponseWriter has the following values: + // The remote will always be 10.240.0.1 and port 40212. + // The local address is always 127.0.0.1 and port 53. + + tests := []struct { + fromOpts []dns.EDNS0 + args []string + toOpts []dns.EDNS0 + }{ + { + []dns.EDNS0{}, + []string{"local", "set", "0xffee", "{qname}"}, + []dns.EDNS0{&dns.EDNS0_LOCAL{Code: 0xffee, Data: []byte("example.com.")}}, + }, + { + []dns.EDNS0{}, + []string{"local", "set", "0xffee", "{qtype}"}, + []dns.EDNS0{&dns.EDNS0_LOCAL{Code: 0xffee, Data: []byte{0x00, 0x01}}}, + }, + { + []dns.EDNS0{}, + []string{"local", "set", "0xffee", "{client_ip}"}, + []dns.EDNS0{&dns.EDNS0_LOCAL{Code: 0xffee, Data: []byte{0x0A, 0xF0, 0x00, 0x01}}}, + }, + { + []dns.EDNS0{}, + []string{"local", "set", "0xffee", "{client_port}"}, + []dns.EDNS0{&dns.EDNS0_LOCAL{Code: 0xffee, Data: []byte{0x9D, 0x14}}}, + }, + { + []dns.EDNS0{}, + []string{"local", "set", "0xffee", "{protocol}"}, + []dns.EDNS0{&dns.EDNS0_LOCAL{Code: 0xffee, Data: []byte("udp")}}, + }, + { + []dns.EDNS0{}, + []string{"local", "set", "0xffee", "{server_ip}"}, + []dns.EDNS0{&dns.EDNS0_LOCAL{Code: 0xffee, Data: []byte{0x7F, 0x00, 0x00, 0x01}}}, + }, + { + []dns.EDNS0{}, + []string{"local", "set", "0xffee", "{server_port}"}, + []dns.EDNS0{&dns.EDNS0_LOCAL{Code: 0xffee, Data: []byte{0x00, 0x35}}}, + }, + } + + ctx := context.TODO() + for i, tc := range tests { + m := new(dns.Msg) + m.SetQuestion("example.com.", dns.TypeA) + m.Question[0].Qclass = dns.ClassINET + + r, err := newEdns0Rule(tc.args...) + if err != nil { + t.Errorf("Error creating test rule: %s", err) + continue + } + rw.Rules = []Rule{r} + + rec := dnsrecorder.New(&test.ResponseWriter{}) + rw.ServeDNS(ctx, rec, m) + + resp := rec.Msg + o := resp.IsEdns0() + if o == nil { + t.Errorf("Test %d: EDNS0 options not set", i) + continue + } + if !optsEqual(o.Option, tc.toOpts) { + t.Errorf("Test %d: Expected %v but got %v", i, tc.toOpts, o) + } + } +} diff --git a/middleware/rewrite/type.go b/middleware/rewrite/type.go index c2b7866d7..58eedd51e 100644 --- a/middleware/rewrite/type.go +++ b/middleware/rewrite/type.go @@ -26,7 +26,7 @@ func newTypeRule(fromS, toS string) (Rule, error) { } // Rewrite rewrites the the current request. -func (rule *typeRule) Rewrite(r *dns.Msg) Result { +func (rule *typeRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result { if rule.fromType > 0 && rule.toType > 0 { if r.Question[0].Qtype == rule.fromType { r.Question[0].Qtype = rule.toType