From d1f17fa7e061d91aa0af7e1fb959a68db899c812 Mon Sep 17 00:00:00 2001 From: Miek Gieben Date: Wed, 7 Sep 2016 11:10:16 +0100 Subject: [PATCH] Cleanup: put middleware helper functions in pkgs (#245) Move all (almost all) Go files in middleware into their own packages. This makes for better naming and discoverability. Lot of changes elsewhere to make this change. The middleware.State was renamed to request.Request which is better, but still does not cover all use-cases. It was also moved out middleware because it is used by `dnsserver` as well. A pkg/dnsutil packages was added for shared, handy, dns util functions. All normalize functions are now put in normalize.go --- core/dnsserver/server.go | 8 +- middleware/cache/cache.go | 31 +- middleware/cache/cache_test.go | 3 +- middleware/cache/handler.go | 3 +- middleware/chaos/chaos.go | 3 +- middleware/chaos/chaos_test.go | 5 +- middleware/dnssec/black_lies_test.go | 4 +- middleware/dnssec/cache_test.go | 4 +- middleware/dnssec/dnskey.go | 4 +- middleware/dnssec/dnssec.go | 15 +- middleware/dnssec/dnssec_test.go | 12 +- middleware/dnssec/handler.go | 3 +- middleware/dnssec/handler_test.go | 10 +- middleware/dnssec/responsewriter.go | 4 +- middleware/errors/errors.go | 8 +- middleware/errors/errors_test.go | 5 +- middleware/errors/setup.go | 4 +- middleware/errors/setup_test.go | 6 +- middleware/etcd/cname_test.go | 6 +- middleware/etcd/debug_test.go | 10 +- middleware/etcd/etcd.go | 2 +- middleware/etcd/group_test.go | 6 +- middleware/etcd/handler.go | 5 +- middleware/etcd/lookup.go | 52 ++-- middleware/etcd/multi_test.go | 6 +- middleware/etcd/other_test.go | 6 +- middleware/etcd/proxy_lookup_test.go | 6 +- middleware/etcd/setup.go | 4 +- middleware/etcd/setup_test.go | 8 +- middleware/etcd/stub_handler.go | 4 +- middleware/etcd/stub_test.go | 7 +- middleware/exchange.go | 8 - middleware/file/delegation_test.go | 6 +- middleware/file/dnssec_test.go | 8 +- middleware/file/ent_test.go | 6 +- middleware/file/file.go | 3 +- middleware/file/lookup_test.go | 10 +- middleware/file/notify.go | 5 +- middleware/file/secondary.go | 4 +- middleware/file/secondary_test.go | 8 +- middleware/file/tree/elem.go | 9 +- .../{canonical.go => file/tree/less.go} | 6 +- .../tree/less_test.go} | 4 +- middleware/file/wildcard_test.go | 6 +- middleware/file/xfr.go | 4 +- middleware/file/zone.go | 4 +- middleware/fs_test.go | 19 -- middleware/host.go | 36 --- middleware/kubernetes/handler.go | 11 +- middleware/kubernetes/kubernetes.go | 6 +- middleware/kubernetes/lookup.go | 80 ++--- middleware/kubernetes/setup.go | 2 +- middleware/loadbalance/loadbalance_test.go | 7 +- middleware/log/log.go | 29 +- middleware/log/log_test.go | 4 +- middleware/log/setup.go | 6 +- middleware/log/setup_test.go | 4 +- middleware/metrics/handler.go | 23 +- middleware/name.go | 25 -- middleware/normalize.go | 78 +++++ middleware/pkg/dnsrecorder/recorder.go | 57 ++++ .../{ => pkg/dnsrecorder}/recorder_test.go | 2 +- middleware/pkg/dnsutil/cname.go | 15 + middleware/pkg/dnsutil/reverse.go | 40 +++ middleware/{ => pkg/edns}/edns.go | 10 +- middleware/{ => pkg/edns}/edns_test.go | 10 +- middleware/{ => pkg/rcode}/rcode.go | 4 +- middleware/{ => pkg/replacer}/replacer.go | 38 +-- .../{ => pkg/replacer}/replacer_test.go | 2 +- middleware/{ => pkg/response}/classify.go | 18 +- .../{ => pkg/response}/classify_test.go | 2 +- middleware/{ => pkg/roller}/roller.go | 4 +- .../pkg/singleflight}/singleflight.go | 0 .../pkg/singleflight}/singleflight_test.go | 0 middleware/{ => pkg/storage}/fs.go | 22 +- middleware/pkg/storage/fs_test.go | 42 +++ middleware/proxy/lookup.go | 14 +- middleware/proxy/reverseproxy.go | 8 +- middleware/recorder.go | 72 ----- middleware/rewrite/condition.go | 7 +- middleware/rewrite/rewrite.go | 140 --------- middleware/rewrite/rewrite_test.go | 5 +- middleware/state_test.go | 289 ------------------ middleware/zone.go | 27 -- middleware/state.go => request/request.go | 126 ++++---- request/request_test.go | 55 ++++ test/etcd_test.go | 4 +- test/proxy_test.go | 4 +- test/server_test.go | 13 +- test/tests.go | 12 +- 90 files changed, 680 insertions(+), 1037 deletions(-) delete mode 100644 middleware/exchange.go rename middleware/{canonical.go => file/tree/less.go} (91%) rename middleware/{canonical_test.go => file/tree/less_test.go} (96%) delete mode 100644 middleware/fs_test.go delete mode 100644 middleware/host.go delete mode 100644 middleware/name.go create mode 100644 middleware/normalize.go create mode 100644 middleware/pkg/dnsrecorder/recorder.go rename middleware/{ => pkg/dnsrecorder}/recorder_test.go (97%) create mode 100644 middleware/pkg/dnsutil/cname.go create mode 100644 middleware/pkg/dnsutil/reverse.go rename middleware/{ => pkg/edns}/edns.go (76%) rename middleware/{ => pkg/edns}/edns_test.go (75%) rename middleware/{ => pkg/rcode}/rcode.go (72%) rename middleware/{ => pkg/replacer}/replacer.go (76%) rename middleware/{ => pkg/replacer}/replacer_test.go (99%) rename middleware/{ => pkg/response}/classify.go (62%) rename middleware/{ => pkg/response}/classify_test.go (97%) rename middleware/{ => pkg/roller}/roller.go (94%) rename {singleflight => middleware/pkg/singleflight}/singleflight.go (100%) rename {singleflight => middleware/pkg/singleflight}/singleflight_test.go (100%) rename middleware/{ => pkg/storage}/fs.go (64%) create mode 100644 middleware/pkg/storage/fs_test.go delete mode 100644 middleware/recorder.go delete mode 100644 middleware/state_test.go delete mode 100644 middleware/zone.go rename middleware/state.go => request/request.go (58%) create mode 100644 request/request_test.go diff --git a/core/dnsserver/server.go b/core/dnsserver/server.go index c29406f79..77a28f477 100644 --- a/core/dnsserver/server.go +++ b/core/dnsserver/server.go @@ -7,7 +7,8 @@ import ( "sync" "time" - "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/middleware/pkg/edns" + "github.com/miekg/coredns/request" "github.com/miekg/dns" "golang.org/x/net/context" @@ -163,7 +164,7 @@ func (s *Server) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { } }() - if m, err := middleware.Edns0Version(r); err != nil { // Wrong EDNS version, return at once. + if m, err := edns.Version(r); err != nil { // Wrong EDNS version, return at once. w.WriteMsg(m) return } @@ -214,10 +215,11 @@ func (s *Server) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { // DefaultErrorFunc responds to an DNS request with an error. func DefaultErrorFunc(w dns.ResponseWriter, r *dns.Msg, rcode int) { - state := middleware.State{W: w, Req: r} + state := request.Request{W: w, Req: r} answer := new(dns.Msg) answer.SetRcode(r, rcode) + state.SizeAndDo(answer) w.WriteMsg(answer) diff --git a/middleware/cache/cache.go b/middleware/cache/cache.go index fd804d445..bd29a815c 100644 --- a/middleware/cache/cache.go +++ b/middleware/cache/cache.go @@ -6,6 +6,7 @@ import ( "time" "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/middleware/pkg/response" "github.com/miekg/dns" gcache "github.com/patrickmn/go-cache" @@ -23,7 +24,7 @@ func NewCache(ttl int, zones []string, next middleware.Handler) Cache { return Cache{Next: next, Zones: zones, cache: gcache.New(defaultDuration, purgeDuration), cap: time.Duration(ttl) * time.Second} } -func cacheKey(m *dns.Msg, t middleware.MsgType, do bool) string { +func cacheKey(m *dns.Msg, t response.Type, do bool) string { if m.Truncated { return "" } @@ -31,15 +32,15 @@ func cacheKey(m *dns.Msg, t middleware.MsgType, do bool) string { qtype := m.Question[0].Qtype qname := strings.ToLower(m.Question[0].Name) switch t { - case middleware.Success: + case response.Success: fallthrough - case middleware.Delegation: + case response.Delegation: return successKey(qname, qtype, do) - case middleware.NameError: + case response.NameError: return nameErrorKey(qname, do) - case middleware.NoData: + case response.NoData: return noDataKey(qname, qtype, do) - case middleware.OtherError: + case response.OtherError: return "" } return "" @@ -57,7 +58,7 @@ func NewCachingResponseWriter(w dns.ResponseWriter, cache *gcache.Cache, cap tim func (c *CachingResponseWriter) WriteMsg(res *dns.Msg) error { do := false - mt, opt := middleware.Classify(res) + mt, opt := response.Classify(res) if opt != nil { do = opt.Do() } @@ -72,7 +73,7 @@ func (c *CachingResponseWriter) WriteMsg(res *dns.Msg) error { return c.ResponseWriter.WriteMsg(res) } -func (c *CachingResponseWriter) set(m *dns.Msg, key string, mt middleware.MsgType) { +func (c *CachingResponseWriter) set(m *dns.Msg, key string, mt response.Type) { if key == "" { log.Printf("[ERROR] Caching called with empty cache key") return @@ -80,21 +81,21 @@ func (c *CachingResponseWriter) set(m *dns.Msg, key string, mt middleware.MsgTyp duration := c.cap switch mt { - case middleware.Success, middleware.Delegation: + case response.Success, response.Delegation: if c.cap == 0 { duration = minTtl(m.Answer, mt) } i := newItem(m, duration) c.cache.Set(key, i, duration) - case middleware.NameError, middleware.NoData: + case response.NameError, response.NoData: if c.cap == 0 { duration = minTtl(m.Ns, mt) } i := newItem(m, duration) c.cache.Set(key, i, duration) - case middleware.OtherError: + case response.OtherError: // don't cache these default: log.Printf("[WARNING] Caching called with unknown middleware MsgType: %d", mt) @@ -112,19 +113,19 @@ func (c *CachingResponseWriter) Hijack() { return } -func minTtl(rrs []dns.RR, mt middleware.MsgType) time.Duration { - if mt != middleware.Success && mt != middleware.NameError && mt != middleware.NoData { +func minTtl(rrs []dns.RR, mt response.Type) time.Duration { + if mt != response.Success && mt != response.NameError && mt != response.NoData { return 0 } minTtl := maxTtl for _, r := range rrs { switch mt { - case middleware.NameError, middleware.NoData: + case response.NameError, response.NoData: if r.Header().Rrtype == dns.TypeSOA { return time.Duration(r.(*dns.SOA).Minttl) * time.Second } - case middleware.Success, middleware.Delegation: + case response.Success, response.Delegation: if r.Header().Ttl < minTtl { minTtl = r.Header().Ttl } diff --git a/middleware/cache/cache_test.go b/middleware/cache/cache_test.go index 452831082..dc44fa8a3 100644 --- a/middleware/cache/cache_test.go +++ b/middleware/cache/cache_test.go @@ -5,6 +5,7 @@ import ( "time" "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/middleware/pkg/response" "github.com/miekg/coredns/middleware/test" "github.com/miekg/dns" @@ -78,7 +79,7 @@ func TestCache(t *testing.T) { m = cacheMsg(m, tc) do := tc.in.Do - mt, _ := middleware.Classify(m) + mt, _ := response.Classify(m) key := cacheKey(m, mt, do) crr.set(m, key, mt) diff --git a/middleware/cache/handler.go b/middleware/cache/handler.go index b891d7278..045c8ab1d 100644 --- a/middleware/cache/handler.go +++ b/middleware/cache/handler.go @@ -2,6 +2,7 @@ package cache import ( "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/request" "github.com/miekg/dns" "github.com/prometheus/client_golang/prometheus" @@ -10,7 +11,7 @@ import ( // ServeDNS implements the middleware.Handler interface. func (c Cache) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { - state := middleware.State{W: w, Req: r} + state := request.Request{W: w, Req: r} qname := state.Name() qtype := state.QType() diff --git a/middleware/chaos/chaos.go b/middleware/chaos/chaos.go index 506298de4..8b6df3e84 100644 --- a/middleware/chaos/chaos.go +++ b/middleware/chaos/chaos.go @@ -4,6 +4,7 @@ import ( "os" "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/request" "github.com/miekg/dns" "golang.org/x/net/context" @@ -18,7 +19,7 @@ type Chaos struct { } func (c Chaos) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { - state := middleware.State{W: w, Req: r} + state := request.Request{W: w, Req: r} if state.QClass() != dns.ClassCHAOS || state.QType() != dns.TypeTXT { return c.Next.ServeDNS(ctx, w, r) } diff --git a/middleware/chaos/chaos_test.go b/middleware/chaos/chaos_test.go index f669eaa7c..6a3261754 100644 --- a/middleware/chaos/chaos_test.go +++ b/middleware/chaos/chaos_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/middleware/pkg/dnsrecorder" "github.com/miekg/coredns/middleware/test" "github.com/miekg/dns" @@ -58,7 +59,7 @@ func TestChaos(t *testing.T) { req.Question[0].Qclass = dns.ClassCHAOS em.Next = tc.next - rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) + rec := dnsrecorder.New(&test.ResponseWriter{}) code, err := em.ServeDNS(ctx, rec, req) if err != tc.expectedErr { @@ -68,7 +69,7 @@ func TestChaos(t *testing.T) { t.Errorf("Test %d: Expected status code %d, but got %d", i, tc.expectedCode, code) } if tc.expectedReply != "" { - answer := rec.Msg().Answer[0].(*dns.TXT).Txt[0] + answer := rec.Msg.Answer[0].(*dns.TXT).Txt[0] if answer != tc.expectedReply { t.Errorf("Test %d: Expected answer %s, but got %s", i, tc.expectedReply, answer) } diff --git a/middleware/dnssec/black_lies_test.go b/middleware/dnssec/black_lies_test.go index 951e8952e..092c36f56 100644 --- a/middleware/dnssec/black_lies_test.go +++ b/middleware/dnssec/black_lies_test.go @@ -4,8 +4,8 @@ import ( "testing" "time" - "github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware/test" + "github.com/miekg/coredns/request" "github.com/miekg/dns" ) @@ -16,7 +16,7 @@ func TestZoneSigningBlackLies(t *testing.T) { defer rm2() m := testNxdomainMsg() - state := middleware.State{Req: m} + state := request.Request{Req: m} m = d.Sign(state, "miek.nl.", time.Now().UTC()) if !section(m.Ns, 2) { t.Errorf("authority section should have 2 sig") diff --git a/middleware/dnssec/cache_test.go b/middleware/dnssec/cache_test.go index 3062f99b0..0f069b6a8 100644 --- a/middleware/dnssec/cache_test.go +++ b/middleware/dnssec/cache_test.go @@ -4,8 +4,8 @@ import ( "testing" "time" - "github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware/test" + "github.com/miekg/coredns/request" ) func TestCacheSet(t *testing.T) { @@ -20,7 +20,7 @@ func TestCacheSet(t *testing.T) { } m := testMsg() - state := middleware.State{Req: m} + state := request.Request{Req: m} k := key(m.Answer) // calculate *before* we add the sig d := New([]string{"miek.nl."}, []*DNSKEY{dnskey}, nil) m = d.Sign(state, "miek.nl.", time.Now().UTC()) diff --git a/middleware/dnssec/dnskey.go b/middleware/dnssec/dnskey.go index 9ae437c54..af345f906 100644 --- a/middleware/dnssec/dnskey.go +++ b/middleware/dnssec/dnskey.go @@ -8,7 +8,7 @@ import ( "os" "time" - "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/request" "github.com/miekg/dns" ) @@ -50,7 +50,7 @@ func ParseKeyFile(pubFile, privFile string) (*DNSKEY, error) { } // getDNSKEY returns the correct DNSKEY to the client. Signatures are added when do is true. -func (d Dnssec) getDNSKEY(state middleware.State, zone string, do bool) *dns.Msg { +func (d Dnssec) getDNSKEY(state request.Request, zone string, do bool) *dns.Msg { keys := make([]dns.RR, len(d.keys)) for i, k := range d.keys { keys[i] = dns.Copy(k.K) diff --git a/middleware/dnssec/dnssec.go b/middleware/dnssec/dnssec.go index f517bfe2c..ea914c0ee 100644 --- a/middleware/dnssec/dnssec.go +++ b/middleware/dnssec/dnssec.go @@ -4,7 +4,9 @@ import ( "time" "github.com/miekg/coredns/middleware" - "github.com/miekg/coredns/singleflight" + "github.com/miekg/coredns/middleware/pkg/response" + "github.com/miekg/coredns/middleware/pkg/singleflight" + "github.com/miekg/coredns/request" "github.com/miekg/dns" gcache "github.com/patrickmn/go-cache" @@ -28,20 +30,21 @@ func New(zones []string, keys []*DNSKEY, next middleware.Handler) Dnssec { } } -// Sign signs the message m. it takes care of negative or nodata responses. It +// Sign signs the message in state. it takes care of negative or nodata responses. It // uses NSEC black lies for authenticated denial of existence. Signatures // creates will be cached for a short while. By default we sign for 8 days, // starting 3 hours ago. -func (d Dnssec) Sign(state middleware.State, zone string, now time.Time) *dns.Msg { +func (d Dnssec) Sign(state request.Request, zone string, now time.Time) *dns.Msg { req := state.Req - mt, _ := middleware.Classify(req) // TODO(miek): need opt record here? - if mt == middleware.Delegation { + + mt, _ := response.Classify(req) // TODO(miek): need opt record here? + if mt == response.Delegation { return req } incep, expir := incepExpir(now) - if mt == middleware.NameError { + if mt == response.NameError { if req.Ns[0].Header().Rrtype != dns.TypeSOA || len(req.Ns) > 1 { return req } diff --git a/middleware/dnssec/dnssec_test.go b/middleware/dnssec/dnssec_test.go index 10f731325..48f13e935 100644 --- a/middleware/dnssec/dnssec_test.go +++ b/middleware/dnssec/dnssec_test.go @@ -4,8 +4,8 @@ import ( "testing" "time" - "github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware/test" + "github.com/miekg/coredns/request" "github.com/miekg/dns" ) @@ -16,7 +16,7 @@ func TestZoneSigning(t *testing.T) { defer rm2() m := testMsg() - state := middleware.State{Req: m} + state := request.Request{Req: m} m = d.Sign(state, "miek.nl.", time.Now().UTC()) if !section(m.Answer, 1) { @@ -44,7 +44,7 @@ func TestZoneSigningDouble(t *testing.T) { d.keys = append(d.keys, key1) m := testMsg() - state := middleware.State{Req: m} + state := request.Request{Req: m} m = d.Sign(state, "miek.nl.", time.Now().UTC()) if !section(m.Answer, 2) { t.Errorf("answer section should have 1 sig") @@ -68,7 +68,7 @@ func TestSigningDifferentZone(t *testing.T) { } m := testMsgEx() - state := middleware.State{Req: m} + state := request.Request{Req: m} d := New([]string{"example.org."}, []*DNSKEY{key}, nil) m = d.Sign(state, "example.org.", time.Now().UTC()) if !section(m.Answer, 1) { @@ -86,7 +86,7 @@ func TestSigningCname(t *testing.T) { defer rm2() m := testMsgCname() - state := middleware.State{Req: m} + state := request.Request{Req: m} m = d.Sign(state, "miek.nl.", time.Now().UTC()) if !section(m.Answer, 1) { t.Errorf("answer section should have 1 sig") @@ -100,7 +100,7 @@ func TestZoneSigningDelegation(t *testing.T) { defer rm2() m := testDelegationMsg() - state := middleware.State{Req: m} + state := request.Request{Req: m} m = d.Sign(state, "miek.nl.", time.Now().UTC()) if !section(m.Ns, 0) { t.Errorf("authority section should have 0 sig") diff --git a/middleware/dnssec/handler.go b/middleware/dnssec/handler.go index 5daf3d322..609150028 100644 --- a/middleware/dnssec/handler.go +++ b/middleware/dnssec/handler.go @@ -2,6 +2,7 @@ package dnssec import ( "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/request" "github.com/miekg/dns" "github.com/prometheus/client_golang/prometheus" @@ -10,7 +11,7 @@ import ( // ServeDNS implements the middleware.Handler interface. func (d Dnssec) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { - state := middleware.State{W: w, Req: r} + state := request.Request{W: w, Req: r} do := state.Do() qname := state.Name() diff --git a/middleware/dnssec/handler_test.go b/middleware/dnssec/handler_test.go index f7cb7e680..a490c3744 100644 --- a/middleware/dnssec/handler_test.go +++ b/middleware/dnssec/handler_test.go @@ -5,8 +5,8 @@ import ( "strings" "testing" - "github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware/file" + "github.com/miekg/coredns/middleware/pkg/dnsrecorder" "github.com/miekg/coredns/middleware/test" "github.com/miekg/dns" @@ -83,14 +83,14 @@ func TestLookupZone(t *testing.T) { for _, tc := range dnsTestCases { m := tc.Msg() - rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) + rec := dnsrecorder.New(&test.ResponseWriter{}) _, err := dh.ServeDNS(ctx, rec, m) if err != nil { t.Errorf("expected no error, got %v\n", err) return } - resp := rec.Msg() + resp := rec.Msg sort.Sort(test.RRSet(resp.Answer)) sort.Sort(test.RRSet(resp.Ns)) sort.Sort(test.RRSet(resp.Extra)) @@ -121,14 +121,14 @@ func TestLookupDNSKEY(t *testing.T) { for _, tc := range dnssecTestCases { m := tc.Msg() - rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) + rec := dnsrecorder.New(&test.ResponseWriter{}) _, err := dh.ServeDNS(ctx, rec, m) if err != nil { t.Errorf("expected no error, got %v\n", err) return } - resp := rec.Msg() + resp := rec.Msg if !resp.Authoritative { t.Errorf("Authoritative Answer should be true, got false") } diff --git a/middleware/dnssec/responsewriter.go b/middleware/dnssec/responsewriter.go index 2a7cbb972..0032fa7ba 100644 --- a/middleware/dnssec/responsewriter.go +++ b/middleware/dnssec/responsewriter.go @@ -5,6 +5,8 @@ import ( "time" "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/request" + "github.com/miekg/dns" ) @@ -20,7 +22,7 @@ func NewDnssecResponseWriter(w dns.ResponseWriter, d Dnssec) *DnssecResponseWrit func (d *DnssecResponseWriter) WriteMsg(res *dns.Msg) error { // By definition we should sign anything that comes back, we should still figure out for // which zone it should be. - state := middleware.State{W: d.ResponseWriter, Req: res} + state := request.Request{W: d.ResponseWriter, Req: res} qname := state.Name() zone := middleware.Zones(d.d.zones).Matches(qname) diff --git a/middleware/errors/errors.go b/middleware/errors/errors.go index 8936b8945..ef178d1a3 100644 --- a/middleware/errors/errors.go +++ b/middleware/errors/errors.go @@ -9,6 +9,8 @@ import ( "time" "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/middleware/pkg/roller" + "github.com/miekg/coredns/request" "github.com/miekg/dns" "golang.org/x/net/context" @@ -19,7 +21,7 @@ type ErrorHandler struct { Next middleware.Handler LogFile string Log *log.Logger - LogRoller *middleware.LogRoller + LogRoller *roller.LogRoller Debug bool // if true, errors are written out to client rather than to a log } @@ -29,7 +31,7 @@ func (h ErrorHandler) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns rcode, err := h.Next.ServeDNS(ctx, w, r) if err != nil { - state := middleware.State{W: w, Req: r} + state := request.Request{W: w, Req: r} errMsg := fmt.Sprintf("%s [ERROR %d %s %s] %v", time.Now().Format(timeFormat), rcode, state.Name(), state.Type(), err) if h.Debug { @@ -53,7 +55,7 @@ func (h ErrorHandler) recovery(ctx context.Context, w dns.ResponseWriter, r *dns return } - state := middleware.State{W: w, Req: r} + state := request.Request{W: w, Req: r} // Obtain source of panic // From: https://gist.github.com/swdunlop/9629168 var name, file string // function name, file name diff --git a/middleware/errors/errors_test.go b/middleware/errors/errors_test.go index 6885aae0f..4c8f56fa2 100644 --- a/middleware/errors/errors_test.go +++ b/middleware/errors/errors_test.go @@ -9,6 +9,7 @@ import ( "testing" "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/middleware/pkg/dnsrecorder" "github.com/miekg/coredns/middleware/test" "github.com/miekg/dns" @@ -47,7 +48,7 @@ func TestErrors(t *testing.T) { for i, tc := range tests { em.Next = tc.next buf.Reset() - rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) + rec := dnsrecorder.New(&test.ResponseWriter{}) code, err := em.ServeDNS(ctx, rec, req) if err != tc.expectedErr { @@ -78,7 +79,7 @@ func TestVisibleErrorWithPanic(t *testing.T) { req := new(dns.Msg) req.SetQuestion("example.org.", dns.TypeA) - rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) + rec := dnsrecorder.New(&test.ResponseWriter{}) code, err := eh.ServeDNS(ctx, rec, req) if code != 0 { diff --git a/middleware/errors/setup.go b/middleware/errors/setup.go index e1c77373d..5c7c1016c 100644 --- a/middleware/errors/setup.go +++ b/middleware/errors/setup.go @@ -6,7 +6,7 @@ import ( "os" "github.com/miekg/coredns/core/dnsserver" - "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/middleware/pkg/roller" "github.com/hashicorp/go-syslog" "github.com/mholt/caddy" @@ -93,7 +93,7 @@ func errorsParse(c *caddy.Controller) (ErrorHandler, error) { if c.NextArg() { if c.Val() == "{" { c.IncrNest() - logRoller, err := middleware.ParseRoller(c) + logRoller, err := roller.Parse(c) if err != nil { return hadBlock, err } diff --git a/middleware/errors/setup_test.go b/middleware/errors/setup_test.go index 6e5a85d08..c1dbf7267 100644 --- a/middleware/errors/setup_test.go +++ b/middleware/errors/setup_test.go @@ -4,7 +4,7 @@ import ( "testing" "github.com/mholt/caddy" - "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/middleware/pkg/roller" ) func TestErrorsParse(t *testing.T) { @@ -29,7 +29,7 @@ func TestErrorsParse(t *testing.T) { }}, {`errors { log errors.txt { size 2 age 10 keep 3 } }`, false, ErrorHandler{ LogFile: "errors.txt", - LogRoller: &middleware.LogRoller{ + LogRoller: &roller.LogRoller{ MaxSize: 2, MaxAge: 10, MaxBackups: 3, @@ -43,7 +43,7 @@ func TestErrorsParse(t *testing.T) { } }`, false, ErrorHandler{ LogFile: "errors.txt", - LogRoller: &middleware.LogRoller{ + LogRoller: &roller.LogRoller{ MaxSize: 3, MaxAge: 11, MaxBackups: 5, diff --git a/middleware/etcd/cname_test.go b/middleware/etcd/cname_test.go index ee341b7b6..fdd0f50e5 100644 --- a/middleware/etcd/cname_test.go +++ b/middleware/etcd/cname_test.go @@ -7,8 +7,8 @@ package etcd import ( "testing" - "github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware/etcd/msg" + "github.com/miekg/coredns/middleware/pkg/dnsrecorder" "github.com/miekg/coredns/middleware/test" "github.com/miekg/dns" @@ -25,14 +25,14 @@ func TestCnameLookup(t *testing.T) { for _, tc := range dnsTestCasesCname { m := tc.Msg() - rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) + rec := dnsrecorder.New(&test.ResponseWriter{}) _, err := etc.ServeDNS(ctxt, rec, m) if err != nil { t.Errorf("expected no error, got %v\n", err) return } - resp := rec.Msg() + resp := rec.Msg if !test.Header(t, tc, resp) { t.Logf("%v\n", resp) continue diff --git a/middleware/etcd/debug_test.go b/middleware/etcd/debug_test.go index 82de9fe1f..930ceb8ce 100644 --- a/middleware/etcd/debug_test.go +++ b/middleware/etcd/debug_test.go @@ -7,8 +7,8 @@ import ( "strings" "testing" - "github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware/etcd/msg" + "github.com/miekg/coredns/middleware/pkg/dnsrecorder" "github.com/miekg/coredns/middleware/test" "github.com/miekg/dns" @@ -41,14 +41,14 @@ func TestDebugLookup(t *testing.T) { for _, tc := range dnsTestCasesDebug { m := tc.Msg() - rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) + rec := dnsrecorder.New(&test.ResponseWriter{}) _, err := etc.ServeDNS(ctxt, rec, m) if err != nil { t.Errorf("expected no error, got %v\n", err) continue } - resp := rec.Msg() + resp := rec.Msg sort.Sort(test.RRSet(resp.Answer)) sort.Sort(test.RRSet(resp.Ns)) sort.Sort(test.RRSet(resp.Extra)) @@ -79,14 +79,14 @@ func TestDebugLookupFalse(t *testing.T) { for _, tc := range dnsTestCasesDebugFalse { m := tc.Msg() - rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) + rec := dnsrecorder.New(&test.ResponseWriter{}) _, err := etc.ServeDNS(ctxt, rec, m) if err != nil { t.Errorf("expected no error, got %v\n", err) continue } - resp := rec.Msg() + resp := rec.Msg sort.Sort(test.RRSet(resp.Answer)) sort.Sort(test.RRSet(resp.Ns)) sort.Sort(test.RRSet(resp.Extra)) diff --git a/middleware/etcd/etcd.go b/middleware/etcd/etcd.go index 87a82e8cb..4b5b424f4 100644 --- a/middleware/etcd/etcd.go +++ b/middleware/etcd/etcd.go @@ -9,8 +9,8 @@ import ( "github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware/etcd/msg" + "github.com/miekg/coredns/middleware/pkg/singleflight" "github.com/miekg/coredns/middleware/proxy" - "github.com/miekg/coredns/singleflight" etcdc "github.com/coreos/etcd/client" "golang.org/x/net/context" diff --git a/middleware/etcd/group_test.go b/middleware/etcd/group_test.go index 7a2808d45..abf777982 100644 --- a/middleware/etcd/group_test.go +++ b/middleware/etcd/group_test.go @@ -6,8 +6,8 @@ import ( "sort" "testing" - "github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware/etcd/msg" + "github.com/miekg/coredns/middleware/pkg/dnsrecorder" "github.com/miekg/coredns/middleware/test" "github.com/miekg/dns" @@ -23,14 +23,14 @@ func TestGroupLookup(t *testing.T) { for _, tc := range dnsTestCasesGroup { m := tc.Msg() - rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) + rec := dnsrecorder.New(&test.ResponseWriter{}) _, err := etc.ServeDNS(ctxt, rec, m) if err != nil { t.Errorf("expected no error, got %v\n", err) continue } - resp := rec.Msg() + resp := rec.Msg sort.Sort(test.RRSet(resp.Answer)) sort.Sort(test.RRSet(resp.Ns)) sort.Sort(test.RRSet(resp.Extra)) diff --git a/middleware/etcd/handler.go b/middleware/etcd/handler.go index 132dba370..40f9523f0 100644 --- a/middleware/etcd/handler.go +++ b/middleware/etcd/handler.go @@ -5,6 +5,7 @@ import ( "github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware/etcd/msg" + "github.com/miekg/coredns/request" "github.com/miekg/dns" "golang.org/x/net/context" @@ -12,7 +13,7 @@ import ( func (e *Etcd) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { opt := Options{} - state := middleware.State{W: w, Req: r} + state := request.Request{W: w, Req: r} if state.QClass() != dns.ClassINET { return dns.RcodeServerFailure, fmt.Errorf("can only deal with ClassINET") } @@ -115,7 +116,7 @@ func (e *Etcd) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) ( } // Err write an error response to the client. -func (e *Etcd) Err(zone string, rcode int, state middleware.State, debug []msg.Service, err error, opt Options) (int, error) { +func (e *Etcd) Err(zone string, rcode int, state request.Request, debug []msg.Service, err error, opt Options) (int, error) { m := new(dns.Msg) m.SetRcode(state.Req, rcode) m.Authoritative, m.RecursionAvailable, m.Compress = true, true, true diff --git a/middleware/etcd/lookup.go b/middleware/etcd/lookup.go index 488ca4f26..b13c257e6 100644 --- a/middleware/etcd/lookup.go +++ b/middleware/etcd/lookup.go @@ -8,6 +8,8 @@ import ( "github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware/etcd/msg" + "github.com/miekg/coredns/middleware/pkg/dnsutil" + "github.com/miekg/coredns/request" "github.com/miekg/dns" ) @@ -16,7 +18,7 @@ type Options struct { Debug string } -func (e Etcd) records(state middleware.State, exact bool, opt Options) (services, debug []msg.Service, err error) { +func (e Etcd) records(state request.Request, exact bool, opt Options) (services, debug []msg.Service, err error) { services, err = e.Records(state.Name(), exact) if err != nil { return @@ -28,7 +30,7 @@ func (e Etcd) records(state middleware.State, exact bool, opt Options) (services return } -func (e Etcd) A(zone string, state middleware.State, previousRecords []dns.RR, opt Options) (records []dns.RR, debug []msg.Service, err error) { +func (e Etcd) A(zone string, state request.Request, previousRecords []dns.RR, opt Options) (records []dns.RR, debug []msg.Service, err error) { services, debug, err := e.records(state, false, opt) if err != nil { return nil, debug, err @@ -49,11 +51,11 @@ func (e Etcd) A(zone string, state middleware.State, previousRecords []dns.RR, o // don't add it, and just continue continue } - if isDuplicateCNAME(newRecord, previousRecords) { + if dnsutil.DuplicateCNAME(newRecord, previousRecords) { continue } - state1 := copyState(state, serv.Host, state.QType()) + state1 := state.NewWithQuestion(serv.Host, state.QType()) nextRecords, nextDebug, err := e.A(zone, state1, append(previousRecords, newRecord), opt) if err == nil { @@ -90,7 +92,7 @@ func (e Etcd) A(zone string, state middleware.State, previousRecords []dns.RR, o return records, debug, nil } -func (e Etcd) AAAA(zone string, state middleware.State, previousRecords []dns.RR, opt Options) (records []dns.RR, debug []msg.Service, err error) { +func (e Etcd) AAAA(zone string, state request.Request, previousRecords []dns.RR, opt Options) (records []dns.RR, debug []msg.Service, err error) { services, debug, err := e.records(state, false, opt) if err != nil { return nil, debug, err @@ -111,11 +113,11 @@ func (e Etcd) AAAA(zone string, state middleware.State, previousRecords []dns.RR // don't add it, and just continue continue } - if isDuplicateCNAME(newRecord, previousRecords) { + if dnsutil.DuplicateCNAME(newRecord, previousRecords) { continue } - state1 := copyState(state, serv.Host, state.QType()) + state1 := state.NewWithQuestion(serv.Host, state.QType()) nextRecords, nextDebug, err := e.AAAA(zone, state1, append(previousRecords, newRecord), opt) if err == nil { @@ -155,7 +157,7 @@ func (e Etcd) AAAA(zone string, state middleware.State, previousRecords []dns.RR // SRV returns SRV records from etcd. // If the Target is not a name but an IP address, a name is created on the fly. -func (e Etcd) SRV(zone string, state middleware.State, opt Options) (records, extra []dns.RR, debug []msg.Service, err error) { +func (e Etcd) SRV(zone string, state request.Request, opt Options) (records, extra []dns.RR, debug []msg.Service, err error) { services, debug, err := e.records(state, false, opt) if err != nil { return nil, nil, nil, err @@ -220,7 +222,7 @@ func (e Etcd) SRV(zone string, state middleware.State, opt Options) (records, ex } // Internal name, we should have some info on them, either v4 or v6 // Clients expect a complete answer, because we are a recursor in their view. - state1 := copyState(state, srv.Target, dns.TypeA) + state1 := state.NewWithQuestion(srv.Target, dns.TypeA) addr, debugAddr, e1 := e.A(zone, state1, nil, opt) if e1 == nil { extra = append(extra, addr...) @@ -246,7 +248,7 @@ func (e Etcd) SRV(zone string, state middleware.State, opt Options) (records, ex // MX returns MX records from etcd. // If the Target is not a name but an IP address, a name is created on the fly. -func (e Etcd) MX(zone string, state middleware.State, opt Options) (records, extra []dns.RR, debug []msg.Service, err error) { +func (e Etcd) MX(zone string, state request.Request, opt Options) (records, extra []dns.RR, debug []msg.Service, err error) { services, debug, err := e.records(state, false, opt) if err != nil { return nil, nil, debug, err @@ -291,7 +293,7 @@ func (e Etcd) MX(zone string, state middleware.State, opt Options) (records, ext break } // Internal name - state1 := copyState(state, mx.Mx, dns.TypeA) + state1 := state.NewWithQuestion(mx.Mx, dns.TypeA) addr, debugAddr, e1 := e.A(zone, state1, nil, opt) if e1 == nil { extra = append(extra, addr...) @@ -311,7 +313,7 @@ func (e Etcd) MX(zone string, state middleware.State, opt Options) (records, ext return records, extra, debug, nil } -func (e Etcd) CNAME(zone string, state middleware.State, opt Options) (records []dns.RR, debug []msg.Service, err error) { +func (e Etcd) CNAME(zone string, state request.Request, opt Options) (records []dns.RR, debug []msg.Service, err error) { services, debug, err := e.records(state, true, opt) if err != nil { return nil, debug, err @@ -327,7 +329,7 @@ func (e Etcd) CNAME(zone string, state middleware.State, opt Options) (records [ } // PTR returns the PTR records, only services that have a domain name as host are included. -func (e Etcd) PTR(zone string, state middleware.State, opt Options) (records []dns.RR, debug []msg.Service, err error) { +func (e Etcd) PTR(zone string, state request.Request, opt Options) (records []dns.RR, debug []msg.Service, err error) { services, debug, err := e.records(state, true, opt) if err != nil { return nil, debug, err @@ -341,7 +343,7 @@ func (e Etcd) PTR(zone string, state middleware.State, opt Options) (records []d return records, debug, nil } -func (e Etcd) TXT(zone string, state middleware.State, opt Options) (records []dns.RR, debug []msg.Service, err error) { +func (e Etcd) TXT(zone string, state request.Request, opt Options) (records []dns.RR, debug []msg.Service, err error) { services, debug, err := e.records(state, false, opt) if err != nil { return nil, debug, err @@ -356,7 +358,7 @@ func (e Etcd) TXT(zone string, state middleware.State, opt Options) (records []d return records, debug, nil } -func (e Etcd) NS(zone string, state middleware.State, opt Options) (records, extra []dns.RR, debug []msg.Service, err error) { +func (e Etcd) NS(zone string, state request.Request, opt Options) (records, extra []dns.RR, debug []msg.Service, err error) { // NS record for this zone live in a special place, ns.dns.. Fake our lookup. // only a tad bit fishy... old := state.QName() @@ -389,7 +391,7 @@ func (e Etcd) NS(zone string, state middleware.State, opt Options) (records, ext } // SOA Record returns a SOA record. -func (e Etcd) SOA(zone string, state middleware.State, opt Options) ([]dns.RR, []msg.Service, error) { +func (e Etcd) SOA(zone string, state request.Request, opt Options) ([]dns.RR, []msg.Service, error) { header := dns.RR_Header{Name: zone, Rrtype: dns.TypeSOA, Ttl: 300, Class: dns.ClassINET} soa := &dns.SOA{Hdr: header, @@ -404,21 +406,3 @@ func (e Etcd) SOA(zone string, state middleware.State, opt Options) ([]dns.RR, [ // TODO(miek): fake some msg.Service here when returning. return []dns.RR{soa}, nil, nil } - -func isDuplicateCNAME(r *dns.CNAME, records []dns.RR) bool { - for _, rec := range records { - if v, ok := rec.(*dns.CNAME); ok { - if v.Target == r.Target { - return true - } - } - } - return false -} - -// TODO(miek): Move to middleware? -func copyState(state middleware.State, target string, typ uint16) middleware.State { - state1 := middleware.State{W: state.W, Req: state.Req.Copy()} - state1.Req.Question[0] = dns.Question{Name: dns.Fqdn(target), Qclass: dns.ClassINET, Qtype: typ} - return state1 -} diff --git a/middleware/etcd/multi_test.go b/middleware/etcd/multi_test.go index f4b59f50b..d4e3493a7 100644 --- a/middleware/etcd/multi_test.go +++ b/middleware/etcd/multi_test.go @@ -6,8 +6,8 @@ import ( "sort" "testing" - "github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware/etcd/msg" + "github.com/miekg/coredns/middleware/pkg/dnsrecorder" "github.com/miekg/coredns/middleware/test" "github.com/miekg/dns" @@ -25,14 +25,14 @@ func TestMultiLookup(t *testing.T) { for _, tc := range dnsTestCasesMulti { m := tc.Msg() - rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) + rec := dnsrecorder.New(&test.ResponseWriter{}) _, err := etc.ServeDNS(ctxt, rec, m) if err != nil { t.Errorf("expected no error, got %v\n", err) return } - resp := rec.Msg() + resp := rec.Msg sort.Sort(test.RRSet(resp.Answer)) sort.Sort(test.RRSet(resp.Ns)) sort.Sort(test.RRSet(resp.Extra)) diff --git a/middleware/etcd/other_test.go b/middleware/etcd/other_test.go index ff37d27d2..2ed622fe0 100644 --- a/middleware/etcd/other_test.go +++ b/middleware/etcd/other_test.go @@ -10,8 +10,8 @@ import ( "strings" "testing" - "github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware/etcd/msg" + "github.com/miekg/coredns/middleware/pkg/dnsrecorder" "github.com/miekg/coredns/middleware/test" "github.com/miekg/dns" @@ -27,14 +27,14 @@ func TestOtherLookup(t *testing.T) { for _, tc := range dnsTestCasesOther { m := tc.Msg() - rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) + rec := dnsrecorder.New(&test.ResponseWriter{}) _, err := etc.ServeDNS(ctxt, rec, m) if err != nil { t.Errorf("expected no error, got %v\n", err) continue } - resp := rec.Msg() + resp := rec.Msg sort.Sort(test.RRSet(resp.Answer)) sort.Sort(test.RRSet(resp.Ns)) sort.Sort(test.RRSet(resp.Extra)) diff --git a/middleware/etcd/proxy_lookup_test.go b/middleware/etcd/proxy_lookup_test.go index 5e0999fb0..8b4697e25 100644 --- a/middleware/etcd/proxy_lookup_test.go +++ b/middleware/etcd/proxy_lookup_test.go @@ -6,8 +6,8 @@ import ( "sort" "testing" - "github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware/etcd/msg" + "github.com/miekg/coredns/middleware/pkg/dnsrecorder" "github.com/miekg/coredns/middleware/proxy" "github.com/miekg/coredns/middleware/test" @@ -27,14 +27,14 @@ func TestProxyLookupFailDebug(t *testing.T) { for _, tc := range dnsTestCasesProxy { m := tc.Msg() - rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) + rec := dnsrecorder.New(&test.ResponseWriter{}) _, err := etc.ServeDNS(ctxt, rec, m) if err != nil { t.Errorf("expected no error, got %v\n", err) continue } - resp := rec.Msg() + resp := rec.Msg sort.Sort(test.RRSet(resp.Answer)) sort.Sort(test.RRSet(resp.Ns)) sort.Sort(test.RRSet(resp.Extra)) diff --git a/middleware/etcd/setup.go b/middleware/etcd/setup.go index dc1dddb0e..58ca8286f 100644 --- a/middleware/etcd/setup.go +++ b/middleware/etcd/setup.go @@ -10,8 +10,8 @@ import ( "github.com/miekg/coredns/core/dnsserver" "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/middleware/pkg/singleflight" "github.com/miekg/coredns/middleware/proxy" - "github.com/miekg/coredns/singleflight" etcdc "github.com/coreos/etcd/client" "github.com/mholt/caddy" @@ -70,7 +70,7 @@ func etcdParse(c *caddy.Controller) (*Etcd, bool, error) { etc.Zones = make([]string, len(c.ServerBlockKeys)) copy(etc.Zones, c.ServerBlockKeys) } - middleware.Zones(etc.Zones).FullyQualify() + middleware.Zones(etc.Zones).Normalize() if c.NextBlock() { // TODO(miek): 2 switches? switch c.Val() { diff --git a/middleware/etcd/setup_test.go b/middleware/etcd/setup_test.go index b522345d2..39ad3d3a1 100644 --- a/middleware/etcd/setup_test.go +++ b/middleware/etcd/setup_test.go @@ -8,11 +8,11 @@ import ( "testing" "time" - "github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware/etcd/msg" + "github.com/miekg/coredns/middleware/pkg/dnsrecorder" + "github.com/miekg/coredns/middleware/pkg/singleflight" "github.com/miekg/coredns/middleware/proxy" "github.com/miekg/coredns/middleware/test" - "github.com/miekg/coredns/singleflight" etcdc "github.com/coreos/etcd/client" "github.com/miekg/dns" @@ -65,14 +65,14 @@ func TestLookup(t *testing.T) { for _, tc := range dnsTestCases { m := tc.Msg() - rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) + rec := dnsrecorder.New(&test.ResponseWriter{}) _, err := etc.ServeDNS(ctxt, rec, m) if err != nil { t.Errorf("expected no error, got: %v for %s %s\n", err, m.Question[0].Name, dns.Type(m.Question[0].Qtype)) return } - resp := rec.Msg() + resp := rec.Msg sort.Sort(test.RRSet(resp.Answer)) sort.Sort(test.RRSet(resp.Ns)) sort.Sort(test.RRSet(resp.Extra)) diff --git a/middleware/etcd/stub_handler.go b/middleware/etcd/stub_handler.go index 9d8778219..10a3d2198 100644 --- a/middleware/etcd/stub_handler.go +++ b/middleware/etcd/stub_handler.go @@ -4,7 +4,7 @@ import ( "errors" "log" - "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/request" "github.com/miekg/dns" "golang.org/x/net/context" @@ -27,7 +27,7 @@ func (s Stub) ServeDNS(ctx context.Context, w dns.ResponseWriter, req *dns.Msg) return dns.RcodeServerFailure, nil } - state := middleware.State{W: w, Req: req} + state := request.Request{W: w, Req: req} m, e := proxy.Forward(state) if e != nil { return dns.RcodeServerFailure, e diff --git a/middleware/etcd/stub_test.go b/middleware/etcd/stub_test.go index b5a101dad..72a8fb7a0 100644 --- a/middleware/etcd/stub_test.go +++ b/middleware/etcd/stub_test.go @@ -8,8 +8,8 @@ import ( "strconv" "testing" - "github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware/etcd/msg" + "github.com/miekg/coredns/middleware/pkg/dnsrecorder" "github.com/miekg/coredns/middleware/test" "github.com/miekg/dns" @@ -53,7 +53,7 @@ func TestStubLookup(t *testing.T) { for _, tc := range dnsTestCasesStub { m := tc.Msg() - rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) + rec := dnsrecorder.New(&test.ResponseWriter{}) _, err := etc.ServeDNS(ctxt, rec, m) if err != nil && m.Question[0].Name == "example.org." { // This is OK, we expect this backend to *not* work. @@ -62,12 +62,11 @@ func TestStubLookup(t *testing.T) { if err != nil { t.Errorf("expected no error, got %v for %s\n", err, m.Question[0].Name) } - resp := rec.Msg() + resp := rec.Msg if resp == nil { // etcd not running? continue } - sort.Sort(test.RRSet(resp.Answer)) sort.Sort(test.RRSet(resp.Ns)) sort.Sort(test.RRSet(resp.Extra)) diff --git a/middleware/exchange.go b/middleware/exchange.go deleted file mode 100644 index 783d06e26..000000000 --- a/middleware/exchange.go +++ /dev/null @@ -1,8 +0,0 @@ -package middleware - -import "github.com/miekg/dns" - -func Exchange(c *dns.Client, m *dns.Msg, server string) (*dns.Msg, error) { - r, _, err := c.Exchange(m, server) - return r, err -} diff --git a/middleware/file/delegation_test.go b/middleware/file/delegation_test.go index 727febade..30a6e430c 100644 --- a/middleware/file/delegation_test.go +++ b/middleware/file/delegation_test.go @@ -5,7 +5,7 @@ import ( "strings" "testing" - "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/middleware/pkg/dnsrecorder" "github.com/miekg/coredns/middleware/test" "github.com/miekg/dns" @@ -57,14 +57,14 @@ func TestLookupDelegation(t *testing.T) { for _, tc := range delegationTestCases { m := tc.Msg() - rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) + rec := dnsrecorder.New(&test.ResponseWriter{}) _, err := fm.ServeDNS(ctx, rec, m) if err != nil { t.Errorf("expected no error, got %v\n", err) return } - resp := rec.Msg() + resp := rec.Msg sort.Sort(test.RRSet(resp.Answer)) sort.Sort(test.RRSet(resp.Ns)) sort.Sort(test.RRSet(resp.Extra)) diff --git a/middleware/file/dnssec_test.go b/middleware/file/dnssec_test.go index 2d76447b3..434ebed8f 100644 --- a/middleware/file/dnssec_test.go +++ b/middleware/file/dnssec_test.go @@ -5,7 +5,7 @@ import ( "strings" "testing" - "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/middleware/pkg/dnsrecorder" "github.com/miekg/coredns/middleware/test" "github.com/miekg/dns" @@ -116,14 +116,14 @@ func TestLookupDNSSEC(t *testing.T) { for _, tc := range dnssecTestCases { m := tc.Msg() - rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) + rec := dnsrecorder.New(&test.ResponseWriter{}) _, err := fm.ServeDNS(ctx, rec, m) if err != nil { t.Errorf("expected no error, got %v\n", err) return } - resp := rec.Msg() + resp := rec.Msg sort.Sort(test.RRSet(resp.Answer)) sort.Sort(test.RRSet(resp.Ns)) sort.Sort(test.RRSet(resp.Extra)) @@ -154,7 +154,7 @@ func BenchmarkLookupDNSSEC(b *testing.B) { fm := File{Next: test.ErrorHandler(), Zones: Zones{Z: map[string]*Zone{testzone: zone}, Names: []string{testzone}}} ctx := context.TODO() - rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) + rec := dnsrecorder.New(&test.ResponseWriter{}) tc := test.Case{ Qname: "b.miek.nl.", Qtype: dns.TypeA, Do: true, diff --git a/middleware/file/ent_test.go b/middleware/file/ent_test.go index 735ec67fe..324c6cfb6 100644 --- a/middleware/file/ent_test.go +++ b/middleware/file/ent_test.go @@ -5,7 +5,7 @@ import ( "strings" "testing" - "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/middleware/pkg/dnsrecorder" "github.com/miekg/coredns/middleware/test" "github.com/miekg/dns" @@ -43,14 +43,14 @@ func TestLookupENT(t *testing.T) { for _, tc := range entTestCases { m := tc.Msg() - rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) + rec := dnsrecorder.New(&test.ResponseWriter{}) _, err := fm.ServeDNS(ctx, rec, m) if err != nil { t.Errorf("expected no error, got %v\n", err) return } - resp := rec.Msg() + resp := rec.Msg sort.Sort(test.RRSet(resp.Answer)) sort.Sort(test.RRSet(resp.Ns)) sort.Sort(test.RRSet(resp.Extra)) diff --git a/middleware/file/file.go b/middleware/file/file.go index a99a64d6f..d2ced17e4 100644 --- a/middleware/file/file.go +++ b/middleware/file/file.go @@ -6,6 +6,7 @@ import ( "log" "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/request" "github.com/miekg/dns" "golang.org/x/net/context" @@ -24,7 +25,7 @@ type ( ) func (f File) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { - state := middleware.State{W: w, Req: r} + state := request.Request{W: w, Req: r} if state.QClass() != dns.ClassINET { return dns.RcodeServerFailure, errors.New("can only deal with ClassINET") diff --git a/middleware/file/lookup_test.go b/middleware/file/lookup_test.go index 90f525e3c..d8efd6ea6 100644 --- a/middleware/file/lookup_test.go +++ b/middleware/file/lookup_test.go @@ -5,7 +5,7 @@ import ( "strings" "testing" - "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/middleware/pkg/dnsrecorder" "github.com/miekg/coredns/middleware/test" "github.com/miekg/dns" @@ -87,14 +87,14 @@ func TestLookup(t *testing.T) { for _, tc := range dnsTestCases { m := tc.Msg() - rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) + rec := dnsrecorder.New(&test.ResponseWriter{}) _, err := fm.ServeDNS(ctx, rec, m) if err != nil { t.Errorf("expected no error, got %v\n", err) return } - resp := rec.Msg() + resp := rec.Msg sort.Sort(test.RRSet(resp.Answer)) sort.Sort(test.RRSet(resp.Ns)) sort.Sort(test.RRSet(resp.Extra)) @@ -122,7 +122,7 @@ func TestLookupNil(t *testing.T) { ctx := context.TODO() m := dnsTestCases[0].Msg() - rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) + rec := dnsrecorder.New(&test.ResponseWriter{}) fm.ServeDNS(ctx, rec, m) } @@ -134,7 +134,7 @@ func BenchmarkLookup(b *testing.B) { fm := File{Next: test.ErrorHandler(), Zones: Zones{Z: map[string]*Zone{testzone: zone}, Names: []string{testzone}}} ctx := context.TODO() - rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) + rec := dnsrecorder.New(&test.ResponseWriter{}) tc := test.Case{ Qname: "www.miek.nl.", Qtype: dns.TypeA, diff --git a/middleware/file/notify.go b/middleware/file/notify.go index 8a2581b84..3c6095a3b 100644 --- a/middleware/file/notify.go +++ b/middleware/file/notify.go @@ -5,6 +5,7 @@ import ( "log" "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/request" "github.com/miekg/dns" ) @@ -12,7 +13,7 @@ import ( // isNotify checks if state is a notify message and if so, will *also* check if it // is from one of the configured masters. If not it will not be a valid notify // message. If the zone z is not a secondary zone the message will also be ignored. -func (z *Zone) isNotify(state middleware.State) bool { +func (z *Zone) isNotify(state request.Request) bool { if state.Req.Opcode != dns.OpcodeNotify { return false } @@ -56,7 +57,7 @@ func notify(zone string, to []string) error { func notifyAddr(c *dns.Client, m *dns.Msg, s string) error { for i := 0; i < 3; i++ { - ret, err := middleware.Exchange(c, m, s) + ret, _, err := c.Exchange(m, s) if err != nil { continue } diff --git a/middleware/file/secondary.go b/middleware/file/secondary.go index 5b1fe0cf2..9493a7fdd 100644 --- a/middleware/file/secondary.go +++ b/middleware/file/secondary.go @@ -4,8 +4,6 @@ import ( "log" "time" - "github.com/miekg/coredns/middleware" - "github.com/miekg/dns" ) @@ -75,7 +73,7 @@ func (z *Zone) shouldTransfer() (bool, error) { Transfer: for _, tr := range z.TransferFrom { Err = nil - ret, err := middleware.Exchange(c, m, tr) + ret, _, err := c.Exchange(m, tr) if err != nil || ret.Rcode != dns.RcodeSuccess { Err = err continue diff --git a/middleware/file/secondary_test.go b/middleware/file/secondary_test.go index 8eff84efb..b32c7aca7 100644 --- a/middleware/file/secondary_test.go +++ b/middleware/file/secondary_test.go @@ -4,8 +4,8 @@ import ( "fmt" "testing" - "github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware/test" + "github.com/miekg/coredns/request" "github.com/miekg/dns" ) @@ -134,7 +134,7 @@ func TestIsNotify(t *testing.T) { z := new(Zone) z.Expired = new(bool) z.origin = testZone - state := NewState(testZone, dns.TypeSOA) + state := newRequest(testZone, dns.TypeSOA) // need to set opcode state.Req.Opcode = dns.OpcodeNotify @@ -148,9 +148,9 @@ func TestIsNotify(t *testing.T) { } } -func NewState(zone string, qtype uint16) middleware.State { +func newRequest(zone string, qtype uint16) request.Request { m := new(dns.Msg) m.SetQuestion("example.com.", dns.TypeA) m.SetEdns0(4097, true) - return middleware.State{W: &test.ResponseWriter{}, Req: m} + return request.Request{W: &test.ResponseWriter{}, Req: m} } diff --git a/middleware/file/tree/elem.go b/middleware/file/tree/elem.go index 6785a6849..785d64660 100644 --- a/middleware/file/tree/elem.go +++ b/middleware/file/tree/elem.go @@ -1,9 +1,6 @@ package tree -import ( - "github.com/miekg/coredns/middleware" - "github.com/miekg/dns" -) +import "github.com/miekg/dns" type Elem struct { m map[uint16][]dns.RR @@ -91,8 +88,8 @@ func (e *Elem) Delete(rr dns.RR) (empty bool) { return } -// Less is a tree helper function that calls middleware.Less. -func Less(a *Elem, name string) int { return middleware.Less(name, a.Name()) } +// Less is a tree helper function that calls less. +func Less(a *Elem, name string) int { return less(name, a.Name()) } // Assuming the same type and name this will check if the rdata is equal as well. func equalRdata(a, b dns.RR) bool { diff --git a/middleware/canonical.go b/middleware/file/tree/less.go similarity index 91% rename from middleware/canonical.go rename to middleware/file/tree/less.go index fd30946bf..32d87b683 100644 --- a/middleware/canonical.go +++ b/middleware/file/tree/less.go @@ -1,4 +1,4 @@ -package middleware +package tree import ( "bytes" @@ -6,7 +6,7 @@ import ( "github.com/miekg/dns" ) -// Less returns <0 when a is less than b, 0 when they are equal and +// less returns <0 when a is less than b, 0 when they are equal and // >0 when a is larger than b. // The function orders names in DNSSEC canonical order: RFC 4034s section-6.1 // @@ -14,7 +14,7 @@ import ( // for a blog article on this implementation. // // The values of a and b are *not* lowercased before the comparison! -func Less(a, b string) int { +func less(a, b string) int { i := 1 aj := len(a) bj := len(b) diff --git a/middleware/canonical_test.go b/middleware/file/tree/less_test.go similarity index 96% rename from middleware/canonical_test.go rename to middleware/file/tree/less_test.go index 1a9fd0859..419b75c55 100644 --- a/middleware/canonical_test.go +++ b/middleware/file/tree/less_test.go @@ -1,4 +1,4 @@ -package middleware +package tree import ( "sort" @@ -10,7 +10,7 @@ type set []string func (p set) Len() int { return len(p) } func (p set) Swap(i, j int) { p[i], p[j] = p[j], p[i] } -func (p set) Less(i, j int) bool { d := Less(p[i], p[j]); return d <= 0 } +func (p set) Less(i, j int) bool { d := less(p[i], p[j]); return d <= 0 } func TestLess(t *testing.T) { tests := []struct { diff --git a/middleware/file/wildcard_test.go b/middleware/file/wildcard_test.go index cff1292cc..b0cc4c610 100644 --- a/middleware/file/wildcard_test.go +++ b/middleware/file/wildcard_test.go @@ -5,7 +5,7 @@ import ( "strings" "testing" - "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/middleware/pkg/dnsrecorder" "github.com/miekg/coredns/middleware/test" "github.com/miekg/dns" @@ -57,14 +57,14 @@ func TestLookupWildcard(t *testing.T) { for _, tc := range wildcardTestCases { m := tc.Msg() - rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) + rec := dnsrecorder.New(&test.ResponseWriter{}) _, err := fm.ServeDNS(ctx, rec, m) if err != nil { t.Errorf("expected no error, got %v\n", err) return } - resp := rec.Msg() + resp := rec.Msg sort.Sort(test.RRSet(resp.Answer)) sort.Sort(test.RRSet(resp.Ns)) sort.Sort(test.RRSet(resp.Extra)) diff --git a/middleware/file/xfr.go b/middleware/file/xfr.go index 1d87a244b..d5ea1f2d3 100644 --- a/middleware/file/xfr.go +++ b/middleware/file/xfr.go @@ -4,7 +4,7 @@ import ( "fmt" "log" - "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/request" "github.com/miekg/dns" "golang.org/x/net/context" @@ -18,7 +18,7 @@ type ( // Serve an AXFR (and fallback of IXFR) as well. func (x Xfr) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { - state := middleware.State{W: w, Req: r} + state := request.Request{W: w, Req: r} if !x.TransferAllowed(state) { return dns.RcodeServerFailure, nil } diff --git a/middleware/file/zone.go b/middleware/file/zone.go index b84162cbb..ff8795d77 100644 --- a/middleware/file/zone.go +++ b/middleware/file/zone.go @@ -8,8 +8,8 @@ import ( "strings" "sync" - "github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware/file/tree" + "github.com/miekg/coredns/request" "github.com/fsnotify/fsnotify" "github.com/miekg/dns" @@ -102,7 +102,7 @@ func (z *Zone) Insert(r dns.RR) error { func (z *Zone) Delete(r dns.RR) { z.Tree.Delete(r) } // TransferAllowed checks if incoming request for transferring the zone is allowed according to the ACLs. -func (z *Zone) TransferAllowed(state middleware.State) bool { +func (z *Zone) TransferAllowed(req request.Request) bool { for _, t := range z.TransferTo { if t == "*" { return true diff --git a/middleware/fs_test.go b/middleware/fs_test.go deleted file mode 100644 index 44133c4eb..000000000 --- a/middleware/fs_test.go +++ /dev/null @@ -1,19 +0,0 @@ -package middleware - -import ( - "os" - "strings" - "testing" -) - -func TestfsPath(t *testing.T) { - if actual := fsPath(); !strings.HasSuffix(actual, ".coredns") { - t.Errorf("Expected path to be a .coredns folder, got: %v", actual) - } - - os.Setenv("COREDNSPATH", "testpath") - defer os.Setenv("COREDNSPATH", "") - if actual, expected := fsPath(), "testpath"; actual != expected { - t.Errorf("Expected path to be %v, got: %v", expected, actual) - } -} diff --git a/middleware/host.go b/middleware/host.go deleted file mode 100644 index 65dbefbca..000000000 --- a/middleware/host.go +++ /dev/null @@ -1,36 +0,0 @@ -package middleware - -import ( - "net" - "strings" - - "github.com/miekg/dns" -) - -// Host represents a host from the Corefile, may contain port. -type ( - Host string - Addr string -) - -// Normalize will return the host portion of host, stripping -// of any port. The host will also be fully qualified and lowercased. -func (h Host) Normalize() string { - // separate host and port - host, _, err := net.SplitHostPort(string(h)) - if err != nil { - host, _, _ = net.SplitHostPort(string(h) + ":") - } - return strings.ToLower(dns.Fqdn(host)) -} - -// Normalize will return a normalized address, if not port is specified -// port 53 is added, otherwise the port will be left as is. -func (a Addr) Normalize() string { - // separate host and port - addr, port, err := net.SplitHostPort(string(a)) - if err != nil { - addr, port, _ = net.SplitHostPort(string(a) + ":53") - } - return net.JoinHostPort(addr, port) -} diff --git a/middleware/kubernetes/handler.go b/middleware/kubernetes/handler.go index 1986820d5..a89dedc0f 100644 --- a/middleware/kubernetes/handler.go +++ b/middleware/kubernetes/handler.go @@ -2,16 +2,17 @@ package kubernetes import ( "fmt" - "strings" "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/middleware/pkg/dnsutil" + "github.com/miekg/coredns/request" "github.com/miekg/dns" "golang.org/x/net/context" ) func (k Kubernetes) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { - state := middleware.State{W: w, Req: r} + state := request.Request{W: w, Req: r} if state.QClass() != dns.ClassINET { return dns.RcodeServerFailure, fmt.Errorf("can only deal with ClassINET") } @@ -21,8 +22,8 @@ func (k Kubernetes) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.M m.Authoritative, m.RecursionAvailable, m.Compress = true, true, true // TODO: find an alternative to this block - if strings.HasSuffix(state.Name(), arpaSuffix) { - ip, _ := extractIP(state.Name()) + ip := dnsutil.ExtractAddressFromReverse(state.Name()) + if ip != "" { records := k.getServiceRecordForIP(ip, state.Name()) if len(records) > 0 { srvPTR := &records[0] @@ -100,7 +101,7 @@ func (k Kubernetes) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.M } // NoData write a nodata response to the client. -func (k Kubernetes) Err(zone string, rcode int, state middleware.State) (int, error) { +func (k Kubernetes) Err(zone string, rcode int, state request.Request) (int, error) { m := new(dns.Msg) m.SetRcode(state.Req, rcode) m.Ns = []dns.RR{k.SOA(zone, state)} diff --git a/middleware/kubernetes/kubernetes.go b/middleware/kubernetes/kubernetes.go index 569e089e0..0bd1dc7a4 100644 --- a/middleware/kubernetes/kubernetes.go +++ b/middleware/kubernetes/kubernetes.go @@ -4,13 +4,13 @@ package kubernetes import ( "errors" "log" - "strings" "time" "github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware/kubernetes/msg" "github.com/miekg/coredns/middleware/kubernetes/nametemplate" "github.com/miekg/coredns/middleware/kubernetes/util" + "github.com/miekg/coredns/middleware/pkg/dnsutil" "github.com/miekg/coredns/middleware/proxy" "github.com/miekg/dns" @@ -100,8 +100,8 @@ func (k *Kubernetes) getZoneForName(name string) (string, []string) { func (k *Kubernetes) Records(name string, exact bool) ([]msg.Service, error) { // TODO: refector this. // Right now GetNamespaceFromSegmentArray do not supports PRE queries - if strings.HasSuffix(name, arpaSuffix) { - ip, _ := extractIP(name) + ip := dnsutil.ExtractAddressFromReverse(name) + if ip != "" { records := k.getServiceRecordForIP(ip, name) return records, nil } diff --git a/middleware/kubernetes/lookup.go b/middleware/kubernetes/lookup.go index 0096e1fdb..e14d2275e 100644 --- a/middleware/kubernetes/lookup.go +++ b/middleware/kubernetes/lookup.go @@ -4,21 +4,17 @@ import ( "fmt" "math" "net" - "strings" "time" "github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware/kubernetes/msg" + "github.com/miekg/coredns/middleware/pkg/dnsutil" + "github.com/miekg/coredns/request" "github.com/miekg/dns" ) -const ( - // arpaSuffix is the standard suffix for PTR IP reverse lookups. - arpaSuffix = ".in-addr.arpa." -) - -func (k Kubernetes) records(state middleware.State, exact bool) ([]msg.Service, error) { +func (k Kubernetes) records(state request.Request, exact bool) ([]msg.Service, error) { services, err := k.Records(state.Name(), exact) if err != nil { return nil, err @@ -28,7 +24,7 @@ func (k Kubernetes) records(state middleware.State, exact bool) ([]msg.Service, return services, nil } -func (k Kubernetes) A(zone string, state middleware.State, previousRecords []dns.RR) (records []dns.RR, err error) { +func (k Kubernetes) A(zone string, state request.Request, previousRecords []dns.RR) (records []dns.RR, err error) { services, err := k.records(state, false) if err != nil { return nil, err @@ -49,11 +45,11 @@ func (k Kubernetes) A(zone string, state middleware.State, previousRecords []dns // don't add it, and just continue continue } - if isDuplicateCNAME(newRecord, previousRecords) { + if dnsutil.DuplicateCNAME(newRecord, previousRecords) { continue } - state1 := copyState(state, serv.Host, state.QType()) + state1 := state.NewWithQuestion(serv.Host, state.QType()) nextRecords, err := k.A(zone, state1, append(previousRecords, newRecord)) if err == nil { @@ -87,7 +83,7 @@ func (k Kubernetes) A(zone string, state middleware.State, previousRecords []dns return records, nil } -func (k Kubernetes) AAAA(zone string, state middleware.State, previousRecords []dns.RR) (records []dns.RR, err error) { +func (k Kubernetes) AAAA(zone string, state request.Request, previousRecords []dns.RR) (records []dns.RR, err error) { services, err := k.records(state, false) if err != nil { return nil, err @@ -108,11 +104,11 @@ func (k Kubernetes) AAAA(zone string, state middleware.State, previousRecords [] // don't add it, and just continue continue } - if isDuplicateCNAME(newRecord, previousRecords) { + if dnsutil.DuplicateCNAME(newRecord, previousRecords) { continue } - state1 := copyState(state, serv.Host, state.QType()) + state1 := state.NewWithQuestion(serv.Host, state.QType()) nextRecords, err := k.AAAA(zone, state1, append(previousRecords, newRecord)) if err == nil { @@ -149,7 +145,7 @@ func (k Kubernetes) AAAA(zone string, state middleware.State, previousRecords [] // SRV returns SRV records from kubernetes. // If the Target is not a name but an IP address, a name is created on the fly. -func (k Kubernetes) SRV(zone string, state middleware.State) (records []dns.RR, extra []dns.RR, err error) { +func (k Kubernetes) SRV(zone string, state request.Request) (records []dns.RR, extra []dns.RR, err error) { services, err := k.records(state, false) if err != nil { return nil, nil, err @@ -207,7 +203,7 @@ func (k Kubernetes) SRV(zone string, state middleware.State) (records []dns.RR, } // Internal name, we should have some info on them, either v4 or v6 // Clients expect a complete answer, because we are a recursor in their view. - state1 := copyState(state, srv.Target, dns.TypeA) + state1 := state.NewWithQuestion(srv.Target, dns.TypeA) addr, e1 := k.A(zone, state1, nil) if e1 == nil { extra = append(extra, addr...) @@ -231,21 +227,21 @@ func (k Kubernetes) SRV(zone string, state middleware.State) (records []dns.RR, } // Returning MX records from kubernetes not implemented. -func (k Kubernetes) MX(zone string, state middleware.State) (records []dns.RR, extra []dns.RR, err error) { +func (k Kubernetes) MX(zone string, state request.Request) (records []dns.RR, extra []dns.RR, err error) { return nil, nil, err } // Returning CNAME records from kubernetes not implemented. -func (k Kubernetes) CNAME(zone string, state middleware.State) (records []dns.RR, err error) { +func (k Kubernetes) CNAME(zone string, state request.Request) (records []dns.RR, err error) { return nil, err } // Returning TXT records from kubernetes not implemented. -func (k Kubernetes) TXT(zone string, state middleware.State) (records []dns.RR, err error) { +func (k Kubernetes) TXT(zone string, state request.Request) (records []dns.RR, err error) { return nil, err } -func (k Kubernetes) NS(zone string, state middleware.State) (records, extra []dns.RR, err error) { +func (k Kubernetes) NS(zone string, state request.Request) (records, extra []dns.RR, err error) { // NS record for this zone live in a special place, ns.dns.. Fake our lookup. // only a tad bit fishy... old := state.QName() @@ -278,7 +274,7 @@ func (k Kubernetes) NS(zone string, state middleware.State) (records, extra []dn } // SOA Record returns a SOA record. -func (k Kubernetes) SOA(zone string, state middleware.State) *dns.SOA { +func (k Kubernetes) SOA(zone string, state request.Request) *dns.SOA { header := dns.RR_Header{Name: zone, Rrtype: dns.TypeSOA, Ttl: 300, Class: dns.ClassINET} return &dns.SOA{Hdr: header, Mbox: "hostmaster." + zone, @@ -291,9 +287,9 @@ func (k Kubernetes) SOA(zone string, state middleware.State) *dns.SOA { } } -func (k Kubernetes) PTR(zone string, state middleware.State) ([]dns.RR, error) { - reverseIP, ok := extractIP(state.Name()) - if !ok { +func (k Kubernetes) PTR(zone string, state request.Request) ([]dns.RR, error) { + reverseIP := dnsutil.ExtractAddressFromReverse(state.Name()) + if reverseIP == "" { return nil, fmt.Errorf("does not support reverse lookup for %s", state.QName()) } @@ -318,41 +314,3 @@ func (k Kubernetes) PTR(zone string, state middleware.State) ([]dns.RR, error) { } return records, nil } - -func isDuplicateCNAME(r *dns.CNAME, records []dns.RR) bool { - for _, rec := range records { - if v, ok := rec.(*dns.CNAME); ok { - if v.Target == r.Target { - return true - } - } - } - return false -} - -func copyState(state middleware.State, target string, typ uint16) middleware.State { - state1 := middleware.State{W: state.W, Req: state.Req.Copy()} - state1.Req.Question[0] = dns.Question{Name: dns.Fqdn(target), Qtype: dns.ClassINET, Qclass: typ} - return state1 -} - -// extractIP turns a standard PTR reverse record lookup name -// into an IP address -func extractIP(reverseName string) (string, bool) { - if !strings.HasSuffix(reverseName, arpaSuffix) { - return "", false - } - search := strings.TrimSuffix(reverseName, arpaSuffix) - - // reverse the segments and then combine them - segments := reverseArray(strings.Split(search, ".")) - return strings.Join(segments, "."), true -} - -func reverseArray(arr []string) []string { - for i := 0; i < len(arr)/2; i++ { - j := len(arr) - i - 1 - arr[i], arr[j] = arr[j], arr[i] - } - return arr -} diff --git a/middleware/kubernetes/setup.go b/middleware/kubernetes/setup.go index fc3c036b8..09e4478e3 100644 --- a/middleware/kubernetes/setup.go +++ b/middleware/kubernetes/setup.go @@ -68,7 +68,7 @@ func kubernetesParse(c *caddy.Controller) (Kubernetes, error) { } k8s.Zones = NormalizeZoneList(zones) - middleware.Zones(k8s.Zones).FullyQualify() + middleware.Zones(k8s.Zones).Normalize() if k8s.Zones == nil || len(k8s.Zones) < 1 { err = errors.New("Zone name must be provided for kubernetes middleware.") diff --git a/middleware/loadbalance/loadbalance_test.go b/middleware/loadbalance/loadbalance_test.go index 5f1ff2ecb..5e240be13 100644 --- a/middleware/loadbalance/loadbalance_test.go +++ b/middleware/loadbalance/loadbalance_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/middleware/pkg/dnsrecorder" "github.com/miekg/coredns/middleware/test" "github.com/miekg/dns" @@ -56,7 +57,7 @@ func TestLoadBalance(t *testing.T) { }, } - rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) + rec := dnsrecorder.New(&test.ResponseWriter{}) for i, test := range tests { req := new(dns.Msg) @@ -71,7 +72,7 @@ func TestLoadBalance(t *testing.T) { } cname := 0 - for _, r := range rec.Msg().Answer { + for _, r := range rec.Msg.Answer { if r.Header().Rrtype != dns.TypeCNAME { break } @@ -81,7 +82,7 @@ func TestLoadBalance(t *testing.T) { t.Errorf("Test %d: Expected %d cnames in Answer, but got %d", i, test.cnameAnswer, cname) } cname = 0 - for _, r := range rec.Msg().Extra { + for _, r := range rec.Msg.Extra { if r.Header().Rrtype != dns.TypeCNAME { break } diff --git a/middleware/log/log.go b/middleware/log/log.go index 32d40632a..9da5e65aa 100644 --- a/middleware/log/log.go +++ b/middleware/log/log.go @@ -7,6 +7,11 @@ import ( "github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware/metrics" + "github.com/miekg/coredns/middleware/pkg/dnsrecorder" + "github.com/miekg/coredns/middleware/pkg/rcode" + "github.com/miekg/coredns/middleware/pkg/replacer" + "github.com/miekg/coredns/middleware/pkg/roller" + "github.com/miekg/coredns/request" "github.com/miekg/dns" "golang.org/x/net/context" @@ -20,32 +25,30 @@ type Logger struct { } func (l Logger) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { - state := middleware.State{W: w, Req: r} + state := request.Request{W: w, Req: r} for _, rule := range l.Rules { if middleware.Name(rule.NameScope).Matches(state.Name()) { - responseRecorder := middleware.NewResponseRecorder(w) - rcode, err := l.Next.ServeDNS(ctx, responseRecorder, r) + responseRecorder := dnsrecorder.New(w) + rc, err := l.Next.ServeDNS(ctx, responseRecorder, r) - if rcode > 0 { + if rc > 0 { // There was an error up the chain, but no response has been written yet. // The error must be handled here so the log entry will record the response size. if l.ErrorFunc != nil { - l.ErrorFunc(responseRecorder, r, rcode) + l.ErrorFunc(responseRecorder, r, rc) } else { - rc := middleware.RcodeToString(rcode) - answer := new(dns.Msg) - answer.SetRcode(r, rcode) + answer.SetRcode(r, rc) state.SizeAndDo(answer) - metrics.Report(state, metrics.Dropped, rc, answer.Len(), time.Now()) + metrics.Report(state, metrics.Dropped, rcode.ToString(rc), answer.Len(), time.Now()) w.WriteMsg(answer) } - rcode = 0 + rc = 0 } - rep := middleware.NewReplacer(r, responseRecorder, CommonLogEmptyValue) + rep := replacer.New(r, responseRecorder, CommonLogEmptyValue) rule.Log.Println(rep.Replace(rule.Format)) - return rcode, err + return rc, err } } @@ -58,7 +61,7 @@ type Rule struct { OutputFile string Format string Log *log.Logger - Roller *middleware.LogRoller + Roller *roller.LogRoller } const ( diff --git a/middleware/log/log_test.go b/middleware/log/log_test.go index b5df1ad76..77e3f2e3c 100644 --- a/middleware/log/log_test.go +++ b/middleware/log/log_test.go @@ -6,7 +6,7 @@ import ( "strings" "testing" - "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/middleware/pkg/dnsrecorder" "github.com/miekg/coredns/middleware/test" "github.com/miekg/dns" @@ -37,7 +37,7 @@ func TestLoggedStatus(t *testing.T) { r := new(dns.Msg) r.SetQuestion("example.org.", dns.TypeA) - rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) + rec := dnsrecorder.New(&test.ResponseWriter{}) rcode, _ := logger.ServeDNS(ctx, rec, r) if rcode != 0 { diff --git a/middleware/log/setup.go b/middleware/log/setup.go index a1e143e6b..0721090a9 100644 --- a/middleware/log/setup.go +++ b/middleware/log/setup.go @@ -6,7 +6,7 @@ import ( "os" "github.com/miekg/coredns/core/dnsserver" - "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/middleware/pkg/roller" "github.com/hashicorp/go-syslog" "github.com/mholt/caddy" @@ -75,13 +75,13 @@ func logParse(c *caddy.Controller) ([]Rule, error) { for c.Next() { args := c.RemainingArgs() - var logRoller *middleware.LogRoller + var logRoller *roller.LogRoller if c.NextBlock() { if c.Val() == "rotate" { if c.NextArg() { if c.Val() == "{" { var err error - logRoller, err = middleware.ParseRoller(c) + logRoller, err = roller.Parse(c) if err != nil { return nil, err } diff --git a/middleware/log/setup_test.go b/middleware/log/setup_test.go index 0a3ee63fe..b38caa52d 100644 --- a/middleware/log/setup_test.go +++ b/middleware/log/setup_test.go @@ -3,7 +3,7 @@ package log import ( "testing" - "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/middleware/pkg/roller" "github.com/mholt/caddy" ) @@ -68,7 +68,7 @@ func TestLogParse(t *testing.T) { NameScope: ".", OutputFile: "access.log", Format: DefaultLogFormat, - Roller: &middleware.LogRoller{ + Roller: &roller.LogRoller{ MaxSize: 2, MaxAge: 10, MaxBackups: 3, diff --git a/middleware/metrics/handler.go b/middleware/metrics/handler.go index 1bdc34897..6cac7e81e 100644 --- a/middleware/metrics/handler.go +++ b/middleware/metrics/handler.go @@ -4,13 +4,16 @@ import ( "time" "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/middleware/pkg/dnsrecorder" + "github.com/miekg/coredns/middleware/pkg/rcode" + "github.com/miekg/coredns/request" "github.com/miekg/dns" "golang.org/x/net/context" ) func (m Metrics) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { - state := middleware.State{W: w, Req: r} + state := request.Request{W: w, Req: r} qname := state.QName() zone := middleware.Zones(m.ZoneNames).Matches(qname) @@ -19,35 +22,35 @@ func (m Metrics) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) } // Record response to get status code and size of the reply. - rw := middleware.NewResponseRecorder(w) + rw := dnsrecorder.New(w) status, err := m.Next.ServeDNS(ctx, rw, r) - Report(state, zone, rw.Rcode(), rw.Size(), rw.Start()) + Report(state, zone, rcode.ToString(rw.Rcode), rw.Size, rw.Start) return status, err } // Report is a plain reporting function that the server can use for REFUSED and other // queries that are turned down because they don't match any middleware. -func Report(state middleware.State, zone, rcode string, size int, start time.Time) { +func Report(req request.Request, zone, rcode string, size int, start time.Time) { if requestCount == nil { // no metrics are enabled return } // Proto and Family - net := state.Proto() + net := req.Proto() fam := "1" - if state.Family() == 2 { + if req.Family() == 2 { fam = "2" } - typ := state.QType() + typ := req.QType() requestCount.WithLabelValues(zone, net, fam).Inc() requestDuration.WithLabelValues(zone).Observe(float64(time.Since(start) / time.Millisecond)) - if state.Do() { + if req.Do() { requestDo.WithLabelValues(zone).Inc() } @@ -59,10 +62,10 @@ func Report(state middleware.State, zone, rcode string, size int, start time.Tim if typ == dns.TypeIXFR || typ == dns.TypeAXFR { responseTransferSize.WithLabelValues(zone, net).Observe(float64(size)) - requestTransferSize.WithLabelValues(zone, net).Observe(float64(state.Size())) + requestTransferSize.WithLabelValues(zone, net).Observe(float64(req.Size())) } else { responseSize.WithLabelValues(zone, net).Observe(float64(size)) - requestSize.WithLabelValues(zone, net).Observe(float64(state.Size())) + requestSize.WithLabelValues(zone, net).Observe(float64(req.Size())) } responseRcode.WithLabelValues(zone, rcode).Inc() diff --git a/middleware/name.go b/middleware/name.go deleted file mode 100644 index 616ae68bd..000000000 --- a/middleware/name.go +++ /dev/null @@ -1,25 +0,0 @@ -package middleware - -import ( - "strings" - - "github.com/miekg/dns" -) - -// Name represents a domain name. -type Name string - -// Matches checks to see if other is a subdomain (or the same domain) of n. -// This method assures that names can be easily and consistently matched. -func (n Name) Matches(child string) bool { - if dns.Name(n) == dns.Name(child) { - return true - } - - return dns.IsSubDomain(string(n), child) -} - -// Normalize lowercases and makes n fully qualified. -func (n Name) Normalize() string { - return strings.ToLower(dns.Fqdn(string(n))) -} diff --git a/middleware/normalize.go b/middleware/normalize.go new file mode 100644 index 000000000..e5b747620 --- /dev/null +++ b/middleware/normalize.go @@ -0,0 +1,78 @@ +package middleware + +import ( + "net" + "strings" + + "github.com/miekg/dns" +) + +type Zones []string + +// Matches checks is qname is a subdomain of any of the zones in z. The match +// will return the most specific zones that matches other. The empty string +// signals a not found condition. +func (z Zones) Matches(qname string) string { + zone := "" + for _, zname := range z { + if dns.IsSubDomain(zname, qname) { + // TODO(miek): hmm, add test for this case + if len(zname) > len(zone) { + zone = zname + } + } + } + return zone +} + +// Normalize fully qualifies all zones in z. +func (z Zones) Normalize() { + for i, _ := range z { + z[i] = Name(z[i]).Normalize() + } +} + +// Name represents a domain name. +type Name string + +// Matches checks to see if other is a subdomain (or the same domain) of n. +// This method assures that names can be easily and consistently matched. +func (n Name) Matches(child string) bool { + if dns.Name(n) == dns.Name(child) { + return true + } + + return dns.IsSubDomain(string(n), child) +} + +// Normalize lowercases and makes n fully qualified. +func (n Name) Normalize() string { return strings.ToLower(dns.Fqdn(string(n))) } + +// Host represents a host from the Corefile, may contain port. +type ( + Host string + Addr string +) + +// Normalize will return the host portion of host, stripping +// of any port. The host will also be fully qualified and lowercased. +func (h Host) Normalize() string { + // separate host and port + host, _, err := net.SplitHostPort(string(h)) + if err != nil { + host, _, _ = net.SplitHostPort(string(h) + ":") + } + return Name(host).Normalize() +} + +// Normalize will return a normalized address, if not port is specified +// port 53 is added, otherwise the port will be left as is. +func (a Addr) Normalize() string { + // separate host and port + addr, port, err := net.SplitHostPort(string(a)) + if err != nil { + addr, port, _ = net.SplitHostPort(string(a) + ":53") + } + // TODO(miek): lowercase it? + return net.JoinHostPort(addr, port) +} diff --git a/middleware/pkg/dnsrecorder/recorder.go b/middleware/pkg/dnsrecorder/recorder.go new file mode 100644 index 000000000..9bf045e91 --- /dev/null +++ b/middleware/pkg/dnsrecorder/recorder.go @@ -0,0 +1,57 @@ +package dnsrecorder + +import ( + "time" + + "github.com/miekg/dns" +) + +// Recorder is a type of ResponseWriter that captures +// the rcode code written to it and also the size of the message +// written in the response. A rcode code does not have +// to be written, however, in which case 0 must be assumed. +// It is best to have the constructor initialize this type +// with that default status code. +type Recorder struct { + dns.ResponseWriter + Rcode int + Size int + Msg *dns.Msg + Start time.Time +} + +// New makes and returns a new Recorder, +// which captures the DNS rcode from the ResponseWriter +// and also the length of the response message written through it. +func New(w dns.ResponseWriter) *Recorder { + return &Recorder{ + ResponseWriter: w, + Rcode: 0, + Msg: nil, + Start: time.Now(), + } +} + +// WriteMsg records the status code and calls the +// underlying ResponseWriter's WriteMsg method. +func (r *Recorder) WriteMsg(res *dns.Msg) error { + r.Rcode = res.Rcode + // We may get called multiple times (axfr for instance). + // Save the last message, but add the sizes. + r.Size += res.Len() + r.Msg = res + return r.ResponseWriter.WriteMsg(res) +} + +// Write is a wrapper that records the size of the message that gets written. +func (r *Recorder) Write(buf []byte) (int, error) { + n, err := r.ResponseWriter.Write(buf) + if err == nil { + r.Size += n + } + return n, err +} + +// Hijack implements dns.Hijacker. It simply wraps the underlying +// ResponseWriter's Hijack method if there is one, or returns an error. +func (r *Recorder) Hijack() { r.ResponseWriter.Hijack(); return } diff --git a/middleware/recorder_test.go b/middleware/pkg/dnsrecorder/recorder_test.go similarity index 97% rename from middleware/recorder_test.go rename to middleware/pkg/dnsrecorder/recorder_test.go index 30931a0e3..c9c2f6ce4 100644 --- a/middleware/recorder_test.go +++ b/middleware/pkg/dnsrecorder/recorder_test.go @@ -1,4 +1,4 @@ -package middleware +package dnsrecorder /* func TestNewResponseRecorder(t *testing.T) { diff --git a/middleware/pkg/dnsutil/cname.go b/middleware/pkg/dnsutil/cname.go new file mode 100644 index 000000000..281e03218 --- /dev/null +++ b/middleware/pkg/dnsutil/cname.go @@ -0,0 +1,15 @@ +package dnsutil + +import "github.com/miekg/dns" + +// DuplicateCNAME returns true if r already exists in records. +func DuplicateCNAME(r *dns.CNAME, records []dns.RR) bool { + for _, rec := range records { + if v, ok := rec.(*dns.CNAME); ok { + if v.Target == r.Target { + return true + } + } + } + return false +} diff --git a/middleware/pkg/dnsutil/reverse.go b/middleware/pkg/dnsutil/reverse.go new file mode 100644 index 000000000..a360432f3 --- /dev/null +++ b/middleware/pkg/dnsutil/reverse.go @@ -0,0 +1,40 @@ +package dnsutil + +import "strings" + +// ExtractAddressFromReverse turns a standard PTR reverse record name +// into an IP address. This works for ipv4 or ipv6. +// +// 54.119.58.176.in-addr.arpa. becomes 176.58.119.54. If the conversion +// failes the empty string is returned. +func ExtractAddressFromReverse(reverseName string) string { + search := "" + + switch { + case strings.HasSuffix(reverseName, v4arpaSuffix): + search = strings.TrimSuffix(reverseName, v4arpaSuffix) + case strings.HasSuffix(reverseName, v6arpaSuffix): + search = strings.TrimSuffix(reverseName, v6arpaSuffix) + default: + return "" + } + + // Reverse the segments and then combine them. + segments := reverse(strings.Split(search, ".")) + return strings.Join(segments, ".") +} + +func reverse(slice []string) []string { + for i := 0; i < len(slice)/2; i++ { + j := len(slice) - i - 1 + slice[i], slice[j] = slice[j], slice[i] + } + return slice +} + +const ( + // v4arpaSuffix is the reverse tree suffix for v4 IP addresses. + v4arpaSuffix = ".in-addr.arpa." + // v6arpaSuffix is the reverse tree suffix for v6 IP addresses. + v6arpaSuffix = ".ip6.arpa." +) diff --git a/middleware/edns.go b/middleware/pkg/edns/edns.go similarity index 76% rename from middleware/edns.go rename to middleware/pkg/edns/edns.go index c60bbc44d..6704066b0 100644 --- a/middleware/edns.go +++ b/middleware/pkg/edns/edns.go @@ -1,4 +1,4 @@ -package middleware +package edns import ( "errors" @@ -6,11 +6,11 @@ import ( "github.com/miekg/dns" ) -// Edns0Version checks the EDNS version in the request. If error +// Version checks the EDNS version in the request. If error // is nil everything is OK and we can invoke the middleware. If non-nil, the // returned Msg is valid to be returned to the client (and should). For some // reason this response should not contain a question RR in the question section. -func Edns0Version(req *dns.Msg) (*dns.Msg, error) { +func Version(req *dns.Msg) (*dns.Msg, error) { opt := req.IsEdns0() if opt == nil { return nil, nil @@ -33,8 +33,8 @@ func Edns0Version(req *dns.Msg) (*dns.Msg, error) { return m, errors.New("EDNS0 BADVERS") } -// edns0Size returns a normalized size based on proto. -func edns0Size(proto string, size int) int { +// Size returns a normalized size based on proto. +func Size(proto string, size int) int { if proto == "tcp" { return dns.MaxMsgSize } diff --git a/middleware/edns_test.go b/middleware/pkg/edns/edns_test.go similarity index 75% rename from middleware/edns_test.go rename to middleware/pkg/edns/edns_test.go index 7b4e6fc66..89ac6d2ec 100644 --- a/middleware/edns_test.go +++ b/middleware/pkg/edns/edns_test.go @@ -1,4 +1,4 @@ -package middleware +package edns import ( "testing" @@ -6,21 +6,21 @@ import ( "github.com/miekg/dns" ) -func TestEdns0Version(t *testing.T) { +func TestVersion(t *testing.T) { m := ednsMsg() m.Extra[0].(*dns.OPT).SetVersion(2) - _, err := Edns0Version(m) + _, err := Version(m) if err == nil { t.Errorf("expected wrong version, but got OK") } } -func TestEdns0VersionNoEdns(t *testing.T) { +func TestVersionNoEdns(t *testing.T) { m := ednsMsg() m.Extra = nil - _, err := Edns0Version(m) + _, err := Version(m) if err != nil { t.Errorf("expected no error, but got one: %s", err) } diff --git a/middleware/rcode.go b/middleware/pkg/rcode/rcode.go similarity index 72% rename from middleware/rcode.go rename to middleware/pkg/rcode/rcode.go index 989f90fdd..006440071 100644 --- a/middleware/rcode.go +++ b/middleware/pkg/rcode/rcode.go @@ -1,4 +1,4 @@ -package middleware +package rcode import ( "strconv" @@ -6,7 +6,7 @@ import ( "github.com/miekg/dns" ) -func RcodeToString(rcode int) string { +func ToString(rcode int) string { if str, ok := dns.RcodeToString[rcode]; ok { return str } diff --git a/middleware/replacer.go b/middleware/pkg/replacer/replacer.go similarity index 76% rename from middleware/replacer.go rename to middleware/pkg/replacer/replacer.go index a2fac3113..e90f29368 100644 --- a/middleware/replacer.go +++ b/middleware/pkg/replacer/replacer.go @@ -1,10 +1,13 @@ -package middleware +package replacer import ( "strconv" "strings" "time" + "github.com/miekg/coredns/middleware/pkg/dnsrecorder" + "github.com/miekg/coredns/request" + "github.com/miekg/dns" ) @@ -22,46 +25,43 @@ type replacer struct { emptyValue string } -// NewReplacer makes a new replacer based on r and rr. +// New makes a new replacer based on r and rr. // Do not create a new replacer until r and rr have all // the needed values, because this function copies those // values into the replacer. rr may be nil if it is not // available. emptyValue should be the string that is used // in place of empty string (can still be empty string). -func NewReplacer(r *dns.Msg, rr *ResponseRecorder, emptyValue string) Replacer { - state := State{W: rr, Req: r} +func New(r *dns.Msg, rr *dnsrecorder.Recorder, emptyValue string) Replacer { + req := request.Request{W: rr, Req: r} rep := replacer{ replacements: map[string]string{ - "{type}": state.Type(), - "{name}": state.Name(), - "{class}": state.Class(), - "{proto}": state.Proto(), + "{type}": req.Type(), + "{name}": req.Name(), + "{class}": req.Class(), + "{proto}": req.Proto(), "{when}": func() string { return time.Now().Format(timeFormat) }(), - "{remote}": state.IP(), - "{port}": func() string { - p, _ := state.Port() - return p - }(), + "{remote}": req.IP(), + "{port}": req.Port(), }, emptyValue: emptyValue, } if rr != nil { - rcode := dns.RcodeToString[rr.rcode] + rcode := dns.RcodeToString[rr.Rcode] if rcode == "" { - rcode = strconv.Itoa(rr.rcode) + rcode = strconv.Itoa(rr.Rcode) } rep.replacements["{rcode}"] = rcode - rep.replacements["{size}"] = strconv.Itoa(rr.size) - rep.replacements["{duration}"] = time.Since(rr.start).String() + rep.replacements["{size}"] = strconv.Itoa(rr.Size) + rep.replacements["{duration}"] = time.Since(rr.Start).String() } // Header placeholders (case-insensitive) rep.replacements[headerReplacer+"id}"] = strconv.Itoa(int(r.Id)) rep.replacements[headerReplacer+"opcode}"] = strconv.Itoa(int(r.Opcode)) - rep.replacements[headerReplacer+"do}"] = boolToString(state.Do()) - rep.replacements[headerReplacer+"bufsize}"] = strconv.Itoa(state.Size()) + rep.replacements[headerReplacer+"do}"] = boolToString(req.Do()) + rep.replacements[headerReplacer+"bufsize}"] = strconv.Itoa(req.Size()) return rep } diff --git a/middleware/replacer_test.go b/middleware/pkg/replacer/replacer_test.go similarity index 99% rename from middleware/replacer_test.go rename to middleware/pkg/replacer/replacer_test.go index 378e4083d..09904c1a9 100644 --- a/middleware/replacer_test.go +++ b/middleware/pkg/replacer/replacer_test.go @@ -1,4 +1,4 @@ -package middleware +package replacer /* func TestNewReplacer(t *testing.T) { diff --git a/middleware/classify.go b/middleware/pkg/response/classify.go similarity index 62% rename from middleware/classify.go rename to middleware/pkg/response/classify.go index 72c131157..adbaa6526 100644 --- a/middleware/classify.go +++ b/middleware/pkg/response/classify.go @@ -1,19 +1,19 @@ -package middleware +package response import "github.com/miekg/dns" -type MsgType int +type Type int const ( - Success MsgType = iota - NameError // NXDOMAIN in header, SOA in auth. - NoData // NOERROR in header, SOA in auth. - Delegation // NOERROR in header, NS in auth, optionally fluff in additional (not checked). - OtherError // Don't cache these. + Success Type = iota + NameError // NXDOMAIN in header, SOA in auth. + NoData // NOERROR in header, SOA in auth. + Delegation // NOERROR in header, NS in auth, optionally fluff in additional (not checked). + OtherError // Don't cache these. ) -// Classify classifies a message, it returns the MessageType. -func Classify(m *dns.Msg) (MsgType, *dns.OPT) { +// Classify classifies a message, it returns the Type. +func Classify(m *dns.Msg) (Type, *dns.OPT) { opt := m.IsEdns0() if len(m.Answer) > 0 && m.Rcode == dns.RcodeSuccess { diff --git a/middleware/classify_test.go b/middleware/pkg/response/classify_test.go similarity index 97% rename from middleware/classify_test.go rename to middleware/pkg/response/classify_test.go index 26c52db55..33ac815e3 100644 --- a/middleware/classify_test.go +++ b/middleware/pkg/response/classify_test.go @@ -1,4 +1,4 @@ -package middleware +package response import ( "testing" diff --git a/middleware/roller.go b/middleware/pkg/roller/roller.go similarity index 94% rename from middleware/roller.go rename to middleware/pkg/roller/roller.go index 81ff71c44..c60494736 100644 --- a/middleware/roller.go +++ b/middleware/pkg/roller/roller.go @@ -1,4 +1,4 @@ -package middleware +package roller import ( "io" @@ -8,7 +8,7 @@ import ( "gopkg.in/natefinch/lumberjack.v2" ) -func ParseRoller(c *caddy.Controller) (*LogRoller, error) { +func Parse(c *caddy.Controller) (*LogRoller, error) { var size, age, keep int // This is kind of a hack to support nested blocks: // As we are already in a block: either log or errors, diff --git a/singleflight/singleflight.go b/middleware/pkg/singleflight/singleflight.go similarity index 100% rename from singleflight/singleflight.go rename to middleware/pkg/singleflight/singleflight.go diff --git a/singleflight/singleflight_test.go b/middleware/pkg/singleflight/singleflight_test.go similarity index 100% rename from singleflight/singleflight_test.go rename to middleware/pkg/singleflight/singleflight_test.go diff --git a/middleware/fs.go b/middleware/pkg/storage/fs.go similarity index 64% rename from middleware/fs.go rename to middleware/pkg/storage/fs.go index 5970d0e99..3ee14b7ed 100644 --- a/middleware/fs.go +++ b/middleware/pkg/storage/fs.go @@ -1,25 +1,39 @@ -package middleware +package storage import ( "net/http" "os" + "path" "path/filepath" "runtime" ) -// dir wraps http.Dir that restrict file access to a specific directory tree. +// dir wraps an http.Dir that restrict file access to a specific directory tree, see http.Dir's documentation +// for methods for accessing files. type dir http.Dir // CoreDir is the directory where middleware can store assets, like zone files after a zone transfer // or public and private keys or anything else a middleware might need. The convention is to place -// assets in a subdirectory named after the fully qualified zone. +// assets in a subdirectory named after the zone prefixed with "D", to prevent the root zone become a hidden directory. // -// example.org./Kexample.key +// Dexample.org/Kexample.org.key +// +// Note that subzone(s) under example.org are places in the own directory under CoreDir: +// +// Dexample.org/... +// Db.example.org/... // // CoreDir will default to "$HOME/.coredns" on Unix, but it's location can be overriden with the COREDNSPATH // environment variable. var CoreDir dir = dir(fsPath()) +func (d dir) Zone(z string) dir { + if z != "." && z[len(z)-2] == '.' { + return dir(path.Join(string(d), "D"+z[:len(z)-1])) + } + return dir(path.Join(string(d), "D"+z)) +} + // fsPath returns the path to the directory where the application may store data. // If COREDNSPATH env variable. is set, that value is used. Otherwise, the path is // the result of evaluating "$HOME/.coredns". diff --git a/middleware/pkg/storage/fs_test.go b/middleware/pkg/storage/fs_test.go new file mode 100644 index 000000000..f7e8ccf9d --- /dev/null +++ b/middleware/pkg/storage/fs_test.go @@ -0,0 +1,42 @@ +package storage + +import ( + "os" + "path" + "strings" + "testing" +) + +func TestfsPath(t *testing.T) { + if actual := fsPath(); !strings.HasSuffix(actual, ".coredns") { + t.Errorf("Expected path to be a .coredns folder, got: %v", actual) + } + + os.Setenv("COREDNSPATH", "testpath") + defer os.Setenv("COREDNSPATH", "") + if actual, expected := fsPath(), "testpath"; actual != expected { + t.Errorf("Expected path to be %v, got: %v", expected, actual) + } +} + +func TestZone(t *testing.T) { + for _, ts := range []string{"example.org.", "example.org"} { + d := CoreDir.Zone(ts) + actual := path.Base(string(d)) + expected := "D" + ts + if actual != expected { + t.Errorf("Expected path to be %v, got %v", actual, expected) + } + } +} + +func TestZoneRoot(t *testing.T) { + for _, ts := range []string{"."} { + d := CoreDir.Zone(ts) + actual := path.Base(string(d)) + expected := "D" + ts + if actual != expected { + t.Errorf("Expected path to be %v, got %v", actual, expected) + } + } +} diff --git a/middleware/proxy/lookup.go b/middleware/proxy/lookup.go index 22a0c77d6..e2a3a0c77 100644 --- a/middleware/proxy/lookup.go +++ b/middleware/proxy/lookup.go @@ -7,7 +7,8 @@ import ( "sync/atomic" "time" - "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/request" + "github.com/miekg/dns" ) @@ -54,18 +55,19 @@ func New(hosts []string) Proxy { // Lookup will use name and type to forge a new message and will send that upstream. It will // set any EDNS0 options correctly so that downstream will be able to process the reply. // Lookup is not suitable for forwarding request. Ssee for that. -func (p Proxy) Lookup(state middleware.State, name string, tpe uint16) (*dns.Msg, error) { +func (p Proxy) Lookup(state request.Request, name string, tpe uint16) (*dns.Msg, error) { req := new(dns.Msg) req.SetQuestion(name, tpe) state.SizeAndDo(req) + return p.lookup(state, req) } -func (p Proxy) Forward(state middleware.State) (*dns.Msg, error) { +func (p Proxy) Forward(state request.Request) (*dns.Msg, error) { return p.lookup(state, state.Req) } -func (p Proxy) lookup(state middleware.State, r *dns.Msg) (*dns.Msg, error) { +func (p Proxy) lookup(state request.Request, r *dns.Msg) (*dns.Msg, error) { var ( reply *dns.Msg err error @@ -84,9 +86,9 @@ func (p Proxy) lookup(state middleware.State, r *dns.Msg) (*dns.Msg, error) { atomic.AddInt64(&host.Conns, 1) if state.Proto() == "tcp" { - reply, err = middleware.Exchange(p.Client.TCP, r, host.Name) + reply, _, err = p.Client.TCP.Exchange(r, host.Name) } else { - reply, err = middleware.Exchange(p.Client.UDP, r, host.Name) + reply, _, err = p.Client.UDP.Exchange(r, host.Name) } atomic.AddInt64(&host.Conns, -1) diff --git a/middleware/proxy/reverseproxy.go b/middleware/proxy/reverseproxy.go index 071452ecb..c452e296c 100644 --- a/middleware/proxy/reverseproxy.go +++ b/middleware/proxy/reverseproxy.go @@ -2,7 +2,7 @@ package proxy import ( - "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/request" "github.com/miekg/dns" ) @@ -20,10 +20,10 @@ func (p ReverseProxy) ServeDNS(w dns.ResponseWriter, r *dns.Msg, extra []dns.RR) ) switch { - case middleware.Proto(w) == "tcp": - reply, err = middleware.Exchange(p.Client.TCP, r, p.Host) + case request.Proto(w) == "tcp": // TODO(miek): keep this in request + reply, _, err = p.Client.TCP.Exchange(r, p.Host) default: - reply, err = middleware.Exchange(p.Client.UDP, r, p.Host) + reply, _, err = p.Client.UDP.Exchange(r, p.Host) } if reply != nil && reply.Truncated { diff --git a/middleware/recorder.go b/middleware/recorder.go deleted file mode 100644 index d1e466ec3..000000000 --- a/middleware/recorder.go +++ /dev/null @@ -1,72 +0,0 @@ -package middleware - -import ( - "time" - - "github.com/miekg/dns" -) - -// ResponseRecorder is a type of ResponseWriter that captures -// the rcode code written to it and also the size of the message -// written in the response. A rcode code does not have -// to be written, however, in which case 0 must be assumed. -// It is best to have the constructor initialize this type -// with that default status code. -type ResponseRecorder struct { - dns.ResponseWriter - rcode int - size int - msg *dns.Msg - start time.Time -} - -// NewResponseRecorder makes and returns a new responseRecorder, -// which captures the DNS rcode from the ResponseWriter -// and also the length of the response message written through it. -func NewResponseRecorder(w dns.ResponseWriter) *ResponseRecorder { - return &ResponseRecorder{ - ResponseWriter: w, - rcode: 0, - msg: nil, - start: time.Now(), - } -} - -// WriteMsg records the status code and calls the -// underlying ResponseWriter's WriteMsg method. -func (r *ResponseRecorder) WriteMsg(res *dns.Msg) error { - r.rcode = res.Rcode - // We may get called multiple times (axfr for instance). - // Save the last message, but add the sizes. - r.size += res.Len() - r.msg = res - return r.ResponseWriter.WriteMsg(res) -} - -// Write is a wrapper that records the size of the message that gets written. -func (r *ResponseRecorder) Write(buf []byte) (int, error) { - n, err := r.ResponseWriter.Write(buf) - if err == nil { - r.size += n - } - return n, err -} - -// Size returns the size. -func (r *ResponseRecorder) Size() int { return r.size } - -// Rcode returns the rcode. -func (r *ResponseRecorder) Rcode() string { return RcodeToString(r.rcode) } - -// Start returns the start time of the ResponseRecorder. -func (r *ResponseRecorder) Start() time.Time { return r.start } - -// Msg returns the written message from the ResponseRecorder. -func (r *ResponseRecorder) Msg() *dns.Msg { return r.msg } - -// Hijack implements dns.Hijacker. It simply wraps the underlying -// ResponseWriter's Hijack method if there is one, or returns an error. -func (r *ResponseRecorder) Hijack() { - r.ResponseWriter.Hijack() - return -} diff --git a/middleware/rewrite/condition.go b/middleware/rewrite/condition.go index 3ccde21f9..e9c565abc 100644 --- a/middleware/rewrite/condition.go +++ b/middleware/rewrite/condition.go @@ -5,7 +5,8 @@ import ( "regexp" "strings" - "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/middleware/pkg/replacer" + "github.com/miekg/dns" ) @@ -25,8 +26,8 @@ func operatorError(operator string) error { return fmt.Errorf("Invalid operator %v", operator) } -func newReplacer(r *dns.Msg) middleware.Replacer { - return middleware.NewReplacer(r, nil, "") +func newReplacer(r *dns.Msg) replacer.Replacer { + return replacer.New(r, nil, "") } // condition is a rewrite condition. diff --git a/middleware/rewrite/rewrite.go b/middleware/rewrite/rewrite.go index 9d8450a3e..2f2f404c9 100644 --- a/middleware/rewrite/rewrite.go +++ b/middleware/rewrite/rewrite.go @@ -117,143 +117,3 @@ func (s SimpleRule) Rewrite(r *dns.Msg) Result { } return RewriteIgnored } - -/* -// ComplexRule is a rewrite rule based on a regular expression -type ComplexRule struct { - // Path base. Request to this path and subpaths will be rewritten - Base string - - // Path to rewrite to - To string - - // If set, neither performs rewrite nor proceeds - // with request. Only returns code. - Status int - - // Extensions to filter by - Exts []string - - // Rewrite conditions - Ifs []If - - *regexp.Regexp -} - -// NewComplexRule creates a new RegexpRule. It returns an error if regexp -// pattern (pattern) or extensions (ext) are invalid. -func NewComplexRule(base, pattern, to string, status int, ext []string, ifs []If) (*ComplexRule, error) { - // validate regexp if present - var r *regexp.Regexp - if pattern != "" { - var err error - r, err = regexp.Compile(pattern) - if err != nil { - return nil, err - } - } - - // validate extensions if present - for _, v := range ext { - if len(v) < 2 || (len(v) < 3 && v[0] == '!') { - // check if no extension is specified - if v != "/" && v != "!/" { - return nil, fmt.Errorf("invalid extension %v", v) - } - } - } - - return &ComplexRule{ - Base: base, - To: to, - Status: status, - Exts: ext, - Ifs: ifs, - Regexp: r, - }, nil -} - -// Rewrite rewrites the internal location of the current request. -func (r *ComplexRule) Rewrite(req *dns.Msg) (re Result) { - rPath := req.URL.Path - replacer := newReplacer(req) - - // validate base - if !middleware.Path(rPath).Matches(r.Base) { - return - } - - // validate extensions - if !r.matchExt(rPath) { - return - } - - // validate regexp if present - if r.Regexp != nil { - // include trailing slash in regexp if present - start := len(r.Base) - if strings.HasSuffix(r.Base, "/") { - start-- - } - - matches := r.FindStringSubmatch(rPath[start:]) - switch len(matches) { - case 0: - // no match - return - default: - // set regexp match variables {1}, {2} ... - for i := 1; i < len(matches); i++ { - replacer.Set(fmt.Sprint(i), matches[i]) - } - } - } - - // validate rewrite conditions - for _, i := range r.Ifs { - if !i.True(req) { - return - } - } - - // if status is present, stop rewrite and return it. - if r.Status != 0 { - return RewriteStatus - } - - // attempt rewrite - return To(fs, req, r.To, replacer) -} - -// matchExt matches rPath against registered file extensions. -// Returns true if a match is found and false otherwise. -func (r *ComplexRule) matchExt(rPath string) bool { - f := filepath.Base(rPath) - ext := path.Ext(f) - if ext == "" { - ext = "/" - } - - mustUse := false - for _, v := range r.Exts { - use := true - if v[0] == '!' { - use = false - v = v[1:] - } - - if use { - mustUse = true - } - - if ext == v { - return use - } - } - - if mustUse { - return false - } - return true -} -*/ diff --git a/middleware/rewrite/rewrite_test.go b/middleware/rewrite/rewrite_test.go index 50aa64ad2..2dc658738 100644 --- a/middleware/rewrite/rewrite_test.go +++ b/middleware/rewrite/rewrite_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/middleware/pkg/dnsrecorder" "github.com/miekg/coredns/middleware/test" "github.com/miekg/dns" @@ -50,10 +51,10 @@ func TestRewrite(t *testing.T) { m.SetQuestion(tc.from, tc.fromT) m.Question[0].Qclass = tc.fromC - rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) + rec := dnsrecorder.New(&test.ResponseWriter{}) rw.ServeDNS(ctx, rec, m) - resp := rec.Msg() + resp := rec.Msg if resp.Question[0].Name != tc.to { t.Errorf("Test %d: Expected Name to be '%s' but was '%s'", i, tc.to, resp.Question[0].Name) } diff --git a/middleware/state_test.go b/middleware/state_test.go deleted file mode 100644 index ae4f1407a..000000000 --- a/middleware/state_test.go +++ /dev/null @@ -1,289 +0,0 @@ -package middleware - -import ( - "testing" - - "github.com/miekg/coredns/middleware/test" - - "github.com/miekg/dns" -) - -func TestStateDo(t *testing.T) { - st := testState() - - st.Do() - if st.do == 0 { - t.Fatalf("expected st.do to be set") - } -} - -func TestStateRemote(t *testing.T) { - st := testState() - if st.IP() != "10.240.0.1" { - t.Fatalf("wrong IP from state") - } - p, err := st.Port() - if err != nil { - t.Fatalf("failed to get Port from state") - } - if p != "40212" { - t.Fatalf("wrong port from state") - } -} - -func BenchmarkStateDo(b *testing.B) { - st := testState() - - for i := 0; i < b.N; i++ { - st.Do() - } -} - -func BenchmarkStateSize(b *testing.B) { - st := testState() - - for i := 0; i < b.N; i++ { - st.Size() - } -} - -func testState() State { - m := new(dns.Msg) - m.SetQuestion("example.com.", dns.TypeA) - m.SetEdns0(4097, true) - return State{W: &test.ResponseWriter{}, Req: m} -} - -/* -func TestHeader(t *testing.T) { - state := getContextOrFail(t) - - headerKey, headerVal := "Header1", "HeaderVal1" - state.Req.Header.Add(headerKey, headerVal) - - actualHeaderVal := state.Header(headerKey) - if actualHeaderVal != headerVal { - t.Errorf("Expected header %s, found %s", headerVal, actualHeaderVal) - } - - missingHeaderVal := state.Header("not-existing") - if missingHeaderVal != "" { - t.Errorf("Expected empty header value, found %s", missingHeaderVal) - } -} - -func TestIP(t *testing.T) { - state := getContextOrFail(t) - - tests := []struct { - inputRemoteAddr string - expectedIP string - }{ - // Test 0 - ipv4 with port - {"1.1.1.1:1111", "1.1.1.1"}, - // Test 1 - ipv4 without port - {"1.1.1.1", "1.1.1.1"}, - // Test 2 - ipv6 with port - {"[::1]:11", "::1"}, - // Test 3 - ipv6 without port and brackets - {"[2001:db8:a0b:12f0::1]", "[2001:db8:a0b:12f0::1]"}, - // Test 4 - ipv6 with zone and port - {`[fe80:1::3%eth0]:44`, `fe80:1::3%eth0`}, - } - - for i, test := range tests { - testPrefix := getTestPrefix(i) - - state.Req.RemoteAddr = test.inputRemoteAddr - actualIP := state.IP() - - if actualIP != test.expectedIP { - t.Errorf(testPrefix+"Expected IP %s, found %s", test.expectedIP, actualIP) - } - } -} - -func TestURL(t *testing.T) { - state := getContextOrFail(t) - - inputURL := "http://localhost" - state.Req.RequestURI = inputURL - - if inputURL != state.URI() { - t.Errorf("Expected url %s, found %s", inputURL, state.URI()) - } -} - -func TestHost(t *testing.T) { - tests := []struct { - input string - expectedHost string - shouldErr bool - }{ - { - input: "localhost:123", - expectedHost: "localhost", - shouldErr: false, - }, - { - input: "localhost", - expectedHost: "localhost", - shouldErr: false, - }, - { - input: "[::]", - expectedHost: "", - shouldErr: true, - }, - } - - for _, test := range tests { - testHostOrPort(t, true, test.input, test.expectedHost, test.shouldErr) - } -} - -func TestPort(t *testing.T) { - tests := []struct { - input string - expectedPort string - shouldErr bool - }{ - { - input: "localhost:123", - expectedPort: "123", - shouldErr: false, - }, - { - input: "localhost", - expectedPort: "80", // assuming 80 is the default port - shouldErr: false, - }, - { - input: ":8080", - expectedPort: "8080", - shouldErr: false, - }, - { - input: "[::]", - expectedPort: "", - shouldErr: true, - }, - } - - for _, test := range tests { - testHostOrPort(t, false, test.input, test.expectedPort, test.shouldErr) - } -} - -func testHostOrPort(t *testing.T, isTestingHost bool, input, expectedResult string, shouldErr bool) { - state := getContextOrFail(t) - - state.Req.Host = input - var actualResult, testedObject string - var err error - - if isTestingHost { - actualResult, err = state.Host() - testedObject = "host" - } else { - actualResult, err = state.Port() - testedObject = "port" - } - - if shouldErr && err == nil { - t.Errorf("Expected error, found nil!") - return - } - - if !shouldErr && err != nil { - t.Errorf("Expected no error, found %s", err) - return - } - - if actualResult != expectedResult { - t.Errorf("Expected %s %s, found %s", testedObject, expectedResult, actualResult) - } -} - -func TestPathMatches(t *testing.T) { - state := getContextOrFail(t) - - tests := []struct { - urlStr string - pattern string - shouldMatch bool - }{ - // Test 0 - { - urlStr: "http://localhost/", - pattern: "", - shouldMatch: true, - }, - // Test 1 - { - urlStr: "http://localhost", - pattern: "", - shouldMatch: true, - }, - // Test 1 - { - urlStr: "http://localhost/", - pattern: "/", - shouldMatch: true, - }, - // Test 3 - { - urlStr: "http://localhost/?param=val", - pattern: "/", - shouldMatch: true, - }, - // Test 4 - { - urlStr: "http://localhost/dir1/dir2", - pattern: "/dir2", - shouldMatch: false, - }, - // Test 5 - { - urlStr: "http://localhost/dir1/dir2", - pattern: "/dir1", - shouldMatch: true, - }, - // Test 6 - { - urlStr: "http://localhost:444/dir1/dir2", - pattern: "/dir1", - shouldMatch: true, - }, - } - - for i, test := range tests { - testPrefix := getTestPrefix(i) - var err error - state.Req.URL, err = url.Parse(test.urlStr) - if err != nil { - t.Fatalf("Failed to prepare test URL from string %s! Error was: %s", test.urlStr, err) - } - - matches := state.PathMatches(test.pattern) - if matches != test.shouldMatch { - t.Errorf(testPrefix+"Expected and actual result differ: expected to match [%t], actual matches [%t]", test.shouldMatch, matches) - } - } -} - -func initTestContext() (Context, error) { - body := bytes.NewBufferString("request body") - request, err := http.NewRequest("GET", "https://localhost", body) - if err != nil { - return Context{}, err - } - - return Context{Root: http.Dir(os.TempDir()), Req: request}, nil -} - - -func getTestPrefix(testN int) string { - return fmt.Sprintf("Test [%d]: ", testN) -} -*/ diff --git a/middleware/zone.go b/middleware/zone.go deleted file mode 100644 index 0f3cd0edc..000000000 --- a/middleware/zone.go +++ /dev/null @@ -1,27 +0,0 @@ -package middleware - -import "github.com/miekg/dns" - -type Zones []string - -// Matches checks is qname is a subdomain of any of the zones in z. The match -// will return the most specific zones that matches other. The empty string -// signals a not found condition. -func (z Zones) Matches(qname string) string { - zone := "" - for _, zname := range z { - if dns.IsSubDomain(zname, qname) { - if len(zname) > len(zone) { - zone = zname - } - } - } - return zone -} - -// FullyQualify fully qualifies all zones in z. -func (z Zones) FullyQualify() { - for i, _ := range z { - z[i] = dns.Fqdn(z[i]) - } -} diff --git a/middleware/state.go b/request/request.go similarity index 58% rename from middleware/state.go rename to request/request.go index 4299641ab..49d7312af 100644 --- a/middleware/state.go +++ b/request/request.go @@ -1,17 +1,16 @@ -package middleware +package request import ( "net" "strings" - "time" + + "github.com/miekg/coredns/middleware/pkg/edns" "github.com/miekg/dns" ) -// This file contains the state functions available for use in the middlewares. - -// State contains some connection state and is useful in middleware. -type State struct { +// Request contains some connection state and is useful in middleware. +type Request struct { Req *dns.Msg W dns.ResponseWriter @@ -24,38 +23,41 @@ type State struct { name string } -// Now returns the current timestamp in the specified format. -func (s *State) Now(format string) string { return time.Now().Format(format) } - -// NowDate returns the current date/time that can be used in other time functions. -func (s *State) NowDate() time.Time { return time.Now() } +// NewWithQuestion returns a new request based on the old, but with a new question +// section in the request. +func (r *Request) NewWithQuestion(name string, typ uint16) Request { + req1 := Request{W: r.W, Req: r.Req.Copy()} + req1.Req.Question[0] = dns.Question{Name: dns.Fqdn(name), Qtype: dns.ClassINET, Qclass: typ} + return req1 +} // IP gets the (remote) IP address of the client making the request. -func (s *State) IP() string { - ip, _, err := net.SplitHostPort(s.W.RemoteAddr().String()) +func (r *Request) IP() string { + ip, _, err := net.SplitHostPort(r.W.RemoteAddr().String()) if err != nil { - return s.W.RemoteAddr().String() + return r.W.RemoteAddr().String() } return ip } // Post gets the (remote) Port of the client making the request. -func (s *State) Port() (string, error) { - _, port, err := net.SplitHostPort(s.W.RemoteAddr().String()) +func (r *Request) Port() string { + _, port, err := net.SplitHostPort(r.W.RemoteAddr().String()) if err != nil { - return "0", err + return "0" } - return port, nil + return port } // RemoteAddr returns the net.Addr of the client that sent the current request. -func (s *State) RemoteAddr() string { - return s.W.RemoteAddr().String() +func (r *Request) RemoteAddr() string { + return r.W.RemoteAddr().String() } // Proto gets the protocol used as the transport. This will be udp or tcp. -func (s *State) Proto() string { return Proto(s.W) } +func (r *Request) Proto() string { return Proto(r.W) } +// FIXME(miek): why not a method on Request // Proto gets the protocol used as the transport. This will be udp or tcp. func Proto(w dns.ResponseWriter) string { if _, ok := w.RemoteAddr().(*net.UDPAddr); ok { @@ -67,11 +69,10 @@ func Proto(w dns.ResponseWriter) string { return "udp" } -// Family returns the family of the transport. -// 1 for IPv4 and 2 for IPv6. -func (s *State) Family() int { +// Family returns the family of the transport, 1 for IPv4 and 2 for IPv6. +func (r *Request) Family() int { var a net.IP - ip := s.W.RemoteAddr() + ip := r.W.RemoteAddr() if i, ok := ip.(*net.UDPAddr); ok { a = i.IP } @@ -86,48 +87,49 @@ func (s *State) Family() int { } // Do returns if the request has the DO (DNSSEC OK) bit set. -func (s *State) Do() bool { - if s.do != 0 { - return s.do == doTrue +func (r *Request) Do() bool { + if r.do != 0 { + return r.do == doTrue } - if o := s.Req.IsEdns0(); o != nil { + if o := r.Req.IsEdns0(); o != nil { if o.Do() { - s.do = doTrue + r.do = doTrue } else { - s.do = doFalse + r.do = doFalse } return o.Do() } - s.do = doFalse + r.do = doFalse return false } -// UDPSize returns if UDP buffer size advertised in the requests OPT record. +// Size returns if UDP buffer size advertised in the requests OPT record. // Or when the request was over TCP, we return the maximum allowed size of 64K. -func (s *State) Size() int { - if s.size != 0 { - return s.size +func (r *Request) Size() int { + if r.size != 0 { + return r.size } size := 0 - if o := s.Req.IsEdns0(); o != nil { + if o := r.Req.IsEdns0(); o != nil { if o.Do() == true { - s.do = doTrue + r.do = doTrue } else { - s.do = doFalse + r.do = doFalse } size = int(o.UDPSize()) } - size = edns0Size(s.Proto(), size) - s.size = size + // TODO(miek) move edns.Size to dnsutil? + size = edns.Size(r.Proto(), size) + r.size = size return size } -// SizeAndDo adds an OPT record that the reflects the intent from state. +// SizeAndDo adds an OPT record that the reflects the intent from request. // The returned bool indicated if an record was found and normalised. -func (s *State) SizeAndDo(m *dns.Msg) bool { - o := s.Req.IsEdns0() // TODO(miek): speed this up +func (r *Request) SizeAndDo(m *dns.Msg) bool { + o := r.Req.IsEdns0() // TODO(miek): speed this up if o == nil { return false } @@ -163,8 +165,8 @@ const ( // the TC bit will be set regardless of protocol, even TCP message will get the bit, the client // should then retry with pigeons. // TODO(referral). -func (s *State) Scrub(reply *dns.Msg) (*dns.Msg, Result) { - size := s.Size() +func (r *Request) Scrub(reply *dns.Msg) (*dns.Msg, Result) { + size := r.Size() l := reply.Len() if size >= l { return reply, ScrubIgnored @@ -173,7 +175,7 @@ func (s *State) Scrub(reply *dns.Msg) (*dns.Msg, Result) { // If not delegation, drop additional section. reply.Extra = nil - s.SizeAndDo(reply) + r.SizeAndDo(reply) l = reply.Len() if size >= l { return reply, ScrubDone @@ -184,42 +186,42 @@ func (s *State) Scrub(reply *dns.Msg) (*dns.Msg, Result) { } // Type returns the type of the question as a string. -func (s *State) Type() string { return dns.Type(s.Req.Question[0].Qtype).String() } +func (r *Request) Type() string { return dns.Type(r.Req.Question[0].Qtype).String() } // QType returns the type of the question as an uint16. -func (s *State) QType() uint16 { return s.Req.Question[0].Qtype } +func (r *Request) QType() uint16 { return r.Req.Question[0].Qtype } // Name returns the name of the question in the request. Note // this name will always have a closing dot and will be lower cased. After a call Name // the value will be cached. To clear this caching call Clear. -func (s *State) Name() string { - if s.name != "" { - return s.name +func (r *Request) Name() string { + if r.name != "" { + return r.name } - s.name = strings.ToLower(dns.Name(s.Req.Question[0].Name).String()) - return s.name + r.name = strings.ToLower(dns.Name(r.Req.Question[0].Name).String()) + return r.name } // QName returns the name of the question in the request. -func (s *State) QName() string { return dns.Name(s.Req.Question[0].Name).String() } +func (r *Request) QName() string { return dns.Name(r.Req.Question[0].Name).String() } // Class returns the class of the question in the request. -func (s *State) Class() string { return dns.Class(s.Req.Question[0].Qclass).String() } +func (r *Request) Class() string { return dns.Class(r.Req.Question[0].Qclass).String() } // QClass returns the class of the question in the request. -func (s *State) QClass() uint16 { return s.Req.Question[0].Qclass } +func (r *Request) QClass() uint16 { return r.Req.Question[0].Qclass } // ErrorMessage returns an error message suitable for sending // back to the client. -func (s *State) ErrorMessage(rcode int) *dns.Msg { +func (r *Request) ErrorMessage(rcode int) *dns.Msg { m := new(dns.Msg) - m.SetRcode(s.Req, rcode) + m.SetRcode(r.Req, rcode) return m } -// Clear clears all caching from State s. -func (s *State) Clear() { - s.name = "" +// Clear clears all caching from Request s. +func (r *Request) Clear() { + r.name = "" } const ( diff --git a/request/request_test.go b/request/request_test.go new file mode 100644 index 000000000..e49e9833f --- /dev/null +++ b/request/request_test.go @@ -0,0 +1,55 @@ +package request + +import ( + "testing" + + "github.com/miekg/coredns/middleware/test" + + "github.com/miekg/dns" +) + +func TestRequestDo(t *testing.T) { + st := testRequest() + + st.Do() + if st.do == 0 { + t.Fatalf("Expected st.do to be set") + } +} + +func TestRequestRemote(t *testing.T) { + st := testRequest() + if st.IP() != "10.240.0.1" { + t.Fatalf("Wrong IP from request") + } + p := st.Port() + if p == "" { + t.Fatalf("Failed to get Port from request") + } + if p != "40212" { + t.Fatalf("Wrong port from request") + } +} + +func BenchmarkRequestDo(b *testing.B) { + st := testRequest() + + for i := 0; i < b.N; i++ { + st.Do() + } +} + +func BenchmarkRequestSize(b *testing.B) { + st := testRequest() + + for i := 0; i < b.N; i++ { + st.Size() + } +} + +func testRequest() Request { + m := new(dns.Msg) + m.SetQuestion("example.com.", dns.TypeA) + m.SetEdns0(4097, true) + return Request{W: &test.ResponseWriter{}, Req: m} +} diff --git a/test/etcd_test.go b/test/etcd_test.go index 6645e9931..6f3c0b39a 100644 --- a/test/etcd_test.go +++ b/test/etcd_test.go @@ -9,11 +9,11 @@ import ( "testing" "time" - "github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware/etcd" "github.com/miekg/coredns/middleware/etcd/msg" "github.com/miekg/coredns/middleware/proxy" "github.com/miekg/coredns/middleware/test" + "github.com/miekg/coredns/request" etcdc "github.com/coreos/etcd/client" "github.com/miekg/dns" @@ -67,7 +67,7 @@ func TestEtcdStubAndProxyLookup(t *testing.T) { } p := proxy.New([]string{udp}) // use udp port from the server - state := middleware.State{W: &test.ResponseWriter{}, Req: new(dns.Msg)} + state := request.Request{W: &test.ResponseWriter{}, Req: new(dns.Msg)} resp, err := p.Lookup(state, "example.com.", dns.TypeA) if err != nil { t.Error("Expected to receive reply, but didn't") diff --git a/test/proxy_test.go b/test/proxy_test.go index 96a2c4c0d..e946f4354 100644 --- a/test/proxy_test.go +++ b/test/proxy_test.go @@ -5,9 +5,9 @@ import ( "log" "testing" - "github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware/proxy" "github.com/miekg/coredns/middleware/test" + "github.com/miekg/coredns/request" "github.com/miekg/dns" ) @@ -46,7 +46,7 @@ func TestLookupProxy(t *testing.T) { log.SetOutput(ioutil.Discard) p := proxy.New([]string{udp}) - state := middleware.State{W: &test.ResponseWriter{}, Req: new(dns.Msg)} + state := request.Request{W: &test.ResponseWriter{}, Req: new(dns.Msg)} resp, err := p.Lookup(state, "example.org.", dns.TypeA) if err != nil { t.Fatal("Expected to receive reply, but didn't") diff --git a/test/server_test.go b/test/server_test.go index a03285bf3..903922a04 100644 --- a/test/server_test.go +++ b/test/server_test.go @@ -17,7 +17,7 @@ func TestProxyToChaosServer(t *testing.T) { t.Fatalf("could not get CoreDNS serving instance: %s", err) } - udpChaos, tcpChaos := CoreDNSServerPorts(chaos, 0) + udpChaos, _ := CoreDNSServerPorts(chaos, 0) defer chaos.Stop() corefileProxy := `.:0 { @@ -32,24 +32,23 @@ func TestProxyToChaosServer(t *testing.T) { udp, _ := CoreDNSServerPorts(proxy, 0) defer proxy.Stop() - chaosTest(t, udpChaos, "udp") - chaosTest(t, tcpChaos, "tcp") + chaosTest(t, udpChaos) - chaosTest(t, udp, "udp") + chaosTest(t, udp) // chaosTest(t, tcp, "tcp"), commented out because we use the original transport to reach the // proxy and we only forward to the udp port. } -func chaosTest(t *testing.T, server, net string) { +func chaosTest(t *testing.T, server string) { m := Msg("version.bind.", dns.TypeTXT, nil) m.Question[0].Qclass = dns.ClassCHAOS - r, err := Exchange(m, server, net) + r, err := dns.Exchange(m, server) if err != nil { t.Fatalf("Could not send message: %s", err) } if r.Rcode != dns.RcodeSuccess || len(r.Answer) == 0 { - t.Fatalf("Expected successful reply on %s, got %s", net, dns.RcodeToString[r.Rcode]) + t.Fatalf("Expected successful reply, got %s", dns.RcodeToString[r.Rcode]) } if r.Answer[0].String() != `version.bind. 0 CH TXT "CoreDNS-001"` { t.Fatalf("Expected version.bind. reply, got %s", r.Answer[0].String()) diff --git a/test/tests.go b/test/tests.go index 0f9d12bae..fb2853e9a 100644 --- a/test/tests.go +++ b/test/tests.go @@ -1,10 +1,6 @@ package test -import ( - "github.com/miekg/coredns/middleware" - - "github.com/miekg/dns" -) +import "github.com/miekg/dns" func Msg(zone string, typ uint16, o *dns.OPT) *dns.Msg { m := new(dns.Msg) @@ -14,9 +10,3 @@ func Msg(zone string, typ uint16, o *dns.OPT) *dns.Msg { } return m } - -func Exchange(m *dns.Msg, server, net string) (*dns.Msg, error) { - c := new(dns.Client) - c.Net = net - return middleware.Exchange(c, m, server) -}