From 8e5d0a23fa7225382f869a6af918f1916bca352c Mon Sep 17 00:00:00 2001 From: Thong Huynh Date: Fri, 8 Sep 2017 13:36:09 -0700 Subject: [PATCH] Add EDNS0_SUBNET rewrite (#1022) * Add EDNS0_SUBNET rewrite * Fix review comments * Update comment * Fix according to review comments * Add ResponseWriter6 instead of parameterized the existing ResponseWriter --- middleware/rewrite/README.md | 17 +++- middleware/rewrite/edns0.go | 102 +++++++++++++++++++++ middleware/rewrite/rewrite_test.go | 141 +++++++++++++++++++++++++++++ middleware/test/responsewriter.go | 17 ++++ 4 files changed, 276 insertions(+), 1 deletion(-) diff --git a/middleware/rewrite/README.md b/middleware/rewrite/README.md index 8fa38a377..63334d09c 100644 --- a/middleware/rewrite/README.md +++ b/middleware/rewrite/README.md @@ -37,7 +37,7 @@ Using FIELD edns0, you can set, append, or replace specific EDNS0 options on the * `append` will add the option regardless of what options already exist * `set` will modify a matching option or add one if none is found -Currently supported are `EDNS0_LOCAL` and `EDNS0_NSID`. +Currently supported are `EDNS0_LOCAL`, `EDNS0_NSID` and `EDNS0_SUBNET`. ### `EDNS0_LOCAL` @@ -74,3 +74,18 @@ rewrite edns0 local set 0xffee {client_ip} This has no fields; it will add an NSID option with an empty string for the NSID. If the option already exists and the action is `replace` or `set`, then the NSID in the option will be set to the empty string. + +### `EDNS0_SUBNET` + +This has two fields, IPv4 bitmask length and IPv6 bitmask length. The bitmask +length is used to extract the client subnet from the source IP address in the query. + +Example: + +~~~ + rewrite edns0 subnet set 24 56 +~~~ + +* If the query has source IP as IPv4, the first 24 bits in the IP will be the network subnet. +* If the query has source IP as IPv6, the first 56 bits in the IP will be the network subnet. + diff --git a/middleware/rewrite/edns0.go b/middleware/rewrite/edns0.go index 0c983ff02..bdfcac6fd 100644 --- a/middleware/rewrite/edns0.go +++ b/middleware/rewrite/edns0.go @@ -133,6 +133,11 @@ func newEdns0Rule(args ...string) (Rule, error) { return nil, fmt.Errorf("EDNS0 NSID rules do not accept args") } return &edns0NsidRule{action: action}, nil + case "subnet": + if len(args) != 4 { + return nil, fmt.Errorf("EDNS0 subnet rules require exactly three args") + } + return newEdns0SubnetRule(action, args[2], args[3]) default: return nil, fmt.Errorf("invalid rule type %q", ruleType) } @@ -304,6 +309,97 @@ func isValidVariable(variable string) bool { return false } +// ends0SubnetRule is a rewrite rule for EDNS0 subnet options +type edns0SubnetRule struct { + v4BitMaskLen uint8 + v6BitMaskLen uint8 + action string +} + +func newEdns0SubnetRule(action, v4BitMaskLen, v6BitMaskLen string) (*edns0SubnetRule, error) { + v4Len, err := strconv.ParseUint(v4BitMaskLen, 0, 16) + if err != nil { + return nil, err + } + // Validate V4 length + if v4Len > maxV4BitMaskLen { + return nil, fmt.Errorf("invalid IPv4 bit mask length %d", v4Len) + } + + v6Len, err := strconv.ParseUint(v6BitMaskLen, 0, 16) + if err != nil { + return nil, err + } + //Validate V6 length + if v6Len > maxV6BitMaskLen { + return nil, fmt.Errorf("invalid IPv6 bit mask length %d", v6Len) + } + + return &edns0SubnetRule{action: action, + v4BitMaskLen: uint8(v4Len), v6BitMaskLen: uint8(v6Len)}, nil +} + +// fillEcsData sets the subnet data into the ecs option +func (rule *edns0SubnetRule) fillEcsData(w dns.ResponseWriter, r *dns.Msg, + ecs *dns.EDNS0_SUBNET) error { + + req := request.Request{W: w, Req: r} + family := req.Family() + if (family != 1) && (family != 2) { + return fmt.Errorf("unable to fill data for EDNS0 subnet due to invalid IP family") + } + + ecs.DraftOption = false + ecs.Family = uint16(family) + ecs.SourceScope = 0 + + ipAddr := req.IP() + switch family { + case 1: + ipv4Mask := net.CIDRMask(int(rule.v4BitMaskLen), 32) + ipv4Addr := net.ParseIP(ipAddr) + ecs.SourceNetmask = rule.v4BitMaskLen + ecs.Address = ipv4Addr.Mask(ipv4Mask).To4() + case 2: + ipv6Mask := net.CIDRMask(int(rule.v6BitMaskLen), 128) + ipv6Addr := net.ParseIP(ipAddr) + ecs.SourceNetmask = rule.v6BitMaskLen + ecs.Address = ipv6Addr.Mask(ipv6Mask).To16() + } + return nil +} + +// Rewrite will alter the request EDNS0 subnet option +func (rule *edns0SubnetRule) Rewrite(w dns.ResponseWriter, r *dns.Msg) Result { + result := RewriteIgnored + o := setupEdns0Opt(r) + found := false + for _, s := range o.Option { + switch e := s.(type) { + case *dns.EDNS0_SUBNET: + if rule.action == Replace || rule.action == Set { + if rule.fillEcsData(w, r, e) == nil { + result = RewriteDone + } + } + found = true + break + } + } + + // add option if not found + if !found && (rule.action == Append || rule.action == Set) { + o.SetDo() + opt := dns.EDNS0_SUBNET{Code: dns.EDNS0SUBNET} + if rule.fillEcsData(w, r, &opt) == nil { + o.Option = append(o.Option, &opt) + result = RewriteDone + } + } + + return result +} + // These are all defined actions. const ( Replace = "replace" @@ -321,3 +417,9 @@ const ( serverIP = "{server_ip}" serverPort = "{server_port}" ) + +// Subnet maximum bit mask length +const ( + maxV4BitMaskLen = 32 + maxV6BitMaskLen = 128 +) diff --git a/middleware/rewrite/rewrite_test.go b/middleware/rewrite/rewrite_test.go index f6337ab5e..39648711e 100644 --- a/middleware/rewrite/rewrite_test.go +++ b/middleware/rewrite/rewrite_test.go @@ -81,6 +81,13 @@ func TestNewRule(t *testing.T) { {[]string{"edns0", "local", "replace", "0xffee", "{protocol}"}, false, reflect.TypeOf(&edns0VariableRule{})}, {[]string{"edns0", "local", "replace", "0xffee", "{server_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})}, {[]string{"edns0", "local", "replace", "0xffee", "{server_port}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "subnet", "set", "-1", "56"}, true, nil}, + {[]string{"edns0", "subnet", "set", "24", "-56"}, true, nil}, + {[]string{"edns0", "subnet", "set", "33", "56"}, true, nil}, + {[]string{"edns0", "subnet", "set", "24", "129"}, true, nil}, + {[]string{"edns0", "subnet", "set", "24", "56"}, false, reflect.TypeOf(&edns0SubnetRule{})}, + {[]string{"edns0", "subnet", "append", "24", "56"}, false, reflect.TypeOf(&edns0SubnetRule{})}, + {[]string{"edns0", "subnet", "replace", "24", "56"}, false, reflect.TypeOf(&edns0SubnetRule{})}, } for i, tc := range tests { @@ -303,6 +310,30 @@ func optsEqual(a, b []dns.EDNS0) bool { } else { return false } + case *dns.EDNS0_SUBNET: + if bb, ok := b[i].(*dns.EDNS0_SUBNET); ok { + if aa.Code != bb.Code { + return false + } + if aa.Family != bb.Family { + return false + } + if aa.SourceNetmask != bb.SourceNetmask { + return false + } + if aa.SourceScope != bb.SourceScope { + return false + } + if !bytes.Equal(aa.Address, bb.Address) { + return false + } + if aa.DraftOption != bb.DraftOption { + return false + } + } else { + return false + } + default: return false } @@ -389,3 +420,113 @@ func TestRewriteEDNS0LocalVariable(t *testing.T) { } } } + +func TestRewriteEDNS0Subnet(t *testing.T) { + rw := Rewrite{ + Next: middleware.HandlerFunc(msgPrinter), + noRevert: true, + } + + tests := []struct { + writer dns.ResponseWriter + fromOpts []dns.EDNS0 + args []string + toOpts []dns.EDNS0 + }{ + { + &test.ResponseWriter{}, + []dns.EDNS0{}, + []string{"subnet", "set", "24", "56"}, + []dns.EDNS0{&dns.EDNS0_SUBNET{Code: 0x8, + Family: 0x1, + SourceNetmask: 0x18, + SourceScope: 0x0, + Address: []byte{0x0A, 0xF0, 0x00, 0x00}, + DraftOption: false}}, + }, + { + &test.ResponseWriter{}, + []dns.EDNS0{}, + []string{"subnet", "set", "32", "56"}, + []dns.EDNS0{&dns.EDNS0_SUBNET{Code: 0x8, + Family: 0x1, + SourceNetmask: 0x20, + SourceScope: 0x0, + Address: []byte{0x0A, 0xF0, 0x00, 0x01}, + DraftOption: false}}, + }, + { + &test.ResponseWriter{}, + []dns.EDNS0{}, + []string{"subnet", "set", "0", "56"}, + []dns.EDNS0{&dns.EDNS0_SUBNET{Code: 0x8, + Family: 0x1, + SourceNetmask: 0x0, + SourceScope: 0x0, + Address: []byte{0x00, 0x00, 0x00, 0x00}, + DraftOption: false}}, + }, + { + &test.ResponseWriter6{}, + []dns.EDNS0{}, + []string{"subnet", "set", "24", "56"}, + []dns.EDNS0{&dns.EDNS0_SUBNET{Code: 0x8, + Family: 0x2, + SourceNetmask: 0x38, + SourceScope: 0x0, + Address: []byte{0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + DraftOption: false}}, + }, + { + &test.ResponseWriter6{}, + []dns.EDNS0{}, + []string{"subnet", "set", "24", "128"}, + []dns.EDNS0{&dns.EDNS0_SUBNET{Code: 0x8, + Family: 0x2, + SourceNetmask: 0x80, + SourceScope: 0x0, + Address: []byte{0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x42, 0x00, 0xff, 0xfe, 0xca, 0x4c, 0x65}, + DraftOption: false}}, + }, + { + &test.ResponseWriter6{}, + []dns.EDNS0{}, + []string{"subnet", "set", "24", "0"}, + []dns.EDNS0{&dns.EDNS0_SUBNET{Code: 0x8, + Family: 0x2, + SourceNetmask: 0x0, + SourceScope: 0x0, + Address: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + DraftOption: false}}, + }, + } + + ctx := context.TODO() + for i, tc := range tests { + m := new(dns.Msg) + m.SetQuestion("example.com.", dns.TypeA) + m.Question[0].Qclass = dns.ClassINET + + r, err := newEdns0Rule(tc.args...) + if err != nil { + t.Errorf("Error creating test rule: %s", err) + continue + } + rw.Rules = []Rule{r} + rec := dnsrecorder.New(tc.writer) + rw.ServeDNS(ctx, rec, m) + + resp := rec.Msg + o := resp.IsEdns0() + if o == nil { + t.Errorf("Test %d: EDNS0 options not set", i) + continue + } + if !optsEqual(o.Option, tc.toOpts) { + t.Errorf("Test %d: Expected %v but got %v", i, tc.toOpts, o) + } + } +} diff --git a/middleware/test/responsewriter.go b/middleware/test/responsewriter.go index 7aa6dd133..79eaa00f3 100644 --- a/middleware/test/responsewriter.go +++ b/middleware/test/responsewriter.go @@ -42,3 +42,20 @@ func (t *ResponseWriter) TsigTimersOnly(bool) { return } // Hijack implement dns.ResponseWriter interface. func (t *ResponseWriter) Hijack() { return } + +// RepsponseWrite6 returns fixed client and remote address in IPv6. The remote +// address is always fe80::42:ff:feca:4c65 and port 40212. The local address +// is always ::1 and port 53. +type ResponseWriter6 struct { + ResponseWriter +} + +// LocalAddr returns the local address, always ::1, port 53 (UDP). +func (t *ResponseWriter6) LocalAddr() net.Addr { + return &net.UDPAddr{IP: net.ParseIP("::1"), Port: 53, Zone: ""} +} + +// RemoteAddr returns the remote address, always fe80::42:ff:feca:4c65 port 40212 (UDP). +func (t *ResponseWriter6) RemoteAddr() net.Addr { + return &net.UDPAddr{IP: net.ParseIP("fe80::42:ff:feca:4c65"), Port: 40212, Zone: ""} +}