diff --git a/core/dnsserver/server.go b/core/dnsserver/server.go index c29406f79..77a28f477 100644 --- a/core/dnsserver/server.go +++ b/core/dnsserver/server.go @@ -7,7 +7,8 @@ import ( "sync" "time" - "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/middleware/pkg/edns" + "github.com/miekg/coredns/request" "github.com/miekg/dns" "golang.org/x/net/context" @@ -163,7 +164,7 @@ func (s *Server) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { } }() - if m, err := middleware.Edns0Version(r); err != nil { // Wrong EDNS version, return at once. + if m, err := edns.Version(r); err != nil { // Wrong EDNS version, return at once. w.WriteMsg(m) return } @@ -214,10 +215,11 @@ func (s *Server) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { // DefaultErrorFunc responds to an DNS request with an error. func DefaultErrorFunc(w dns.ResponseWriter, r *dns.Msg, rcode int) { - state := middleware.State{W: w, Req: r} + state := request.Request{W: w, Req: r} answer := new(dns.Msg) answer.SetRcode(r, rcode) + state.SizeAndDo(answer) w.WriteMsg(answer) diff --git a/middleware/cache/cache.go b/middleware/cache/cache.go index fd804d445..bd29a815c 100644 --- a/middleware/cache/cache.go +++ b/middleware/cache/cache.go @@ -6,6 +6,7 @@ import ( "time" "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/middleware/pkg/response" "github.com/miekg/dns" gcache "github.com/patrickmn/go-cache" @@ -23,7 +24,7 @@ func NewCache(ttl int, zones []string, next middleware.Handler) Cache { return Cache{Next: next, Zones: zones, cache: gcache.New(defaultDuration, purgeDuration), cap: time.Duration(ttl) * time.Second} } -func cacheKey(m *dns.Msg, t middleware.MsgType, do bool) string { +func cacheKey(m *dns.Msg, t response.Type, do bool) string { if m.Truncated { return "" } @@ -31,15 +32,15 @@ func cacheKey(m *dns.Msg, t middleware.MsgType, do bool) string { qtype := m.Question[0].Qtype qname := strings.ToLower(m.Question[0].Name) switch t { - case middleware.Success: + case response.Success: fallthrough - case middleware.Delegation: + case response.Delegation: return successKey(qname, qtype, do) - case middleware.NameError: + case response.NameError: return nameErrorKey(qname, do) - case middleware.NoData: + case response.NoData: return noDataKey(qname, qtype, do) - case middleware.OtherError: + case response.OtherError: return "" } return "" @@ -57,7 +58,7 @@ func NewCachingResponseWriter(w dns.ResponseWriter, cache *gcache.Cache, cap tim func (c *CachingResponseWriter) WriteMsg(res *dns.Msg) error { do := false - mt, opt := middleware.Classify(res) + mt, opt := response.Classify(res) if opt != nil { do = opt.Do() } @@ -72,7 +73,7 @@ func (c *CachingResponseWriter) WriteMsg(res *dns.Msg) error { return c.ResponseWriter.WriteMsg(res) } -func (c *CachingResponseWriter) set(m *dns.Msg, key string, mt middleware.MsgType) { +func (c *CachingResponseWriter) set(m *dns.Msg, key string, mt response.Type) { if key == "" { log.Printf("[ERROR] Caching called with empty cache key") return @@ -80,21 +81,21 @@ func (c *CachingResponseWriter) set(m *dns.Msg, key string, mt middleware.MsgTyp duration := c.cap switch mt { - case middleware.Success, middleware.Delegation: + case response.Success, response.Delegation: if c.cap == 0 { duration = minTtl(m.Answer, mt) } i := newItem(m, duration) c.cache.Set(key, i, duration) - case middleware.NameError, middleware.NoData: + case response.NameError, response.NoData: if c.cap == 0 { duration = minTtl(m.Ns, mt) } i := newItem(m, duration) c.cache.Set(key, i, duration) - case middleware.OtherError: + case response.OtherError: // don't cache these default: log.Printf("[WARNING] Caching called with unknown middleware MsgType: %d", mt) @@ -112,19 +113,19 @@ func (c *CachingResponseWriter) Hijack() { return } -func minTtl(rrs []dns.RR, mt middleware.MsgType) time.Duration { - if mt != middleware.Success && mt != middleware.NameError && mt != middleware.NoData { +func minTtl(rrs []dns.RR, mt response.Type) time.Duration { + if mt != response.Success && mt != response.NameError && mt != response.NoData { return 0 } minTtl := maxTtl for _, r := range rrs { switch mt { - case middleware.NameError, middleware.NoData: + case response.NameError, response.NoData: if r.Header().Rrtype == dns.TypeSOA { return time.Duration(r.(*dns.SOA).Minttl) * time.Second } - case middleware.Success, middleware.Delegation: + case response.Success, response.Delegation: if r.Header().Ttl < minTtl { minTtl = r.Header().Ttl } diff --git a/middleware/cache/cache_test.go b/middleware/cache/cache_test.go index 452831082..dc44fa8a3 100644 --- a/middleware/cache/cache_test.go +++ b/middleware/cache/cache_test.go @@ -5,6 +5,7 @@ import ( "time" "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/middleware/pkg/response" "github.com/miekg/coredns/middleware/test" "github.com/miekg/dns" @@ -78,7 +79,7 @@ func TestCache(t *testing.T) { m = cacheMsg(m, tc) do := tc.in.Do - mt, _ := middleware.Classify(m) + mt, _ := response.Classify(m) key := cacheKey(m, mt, do) crr.set(m, key, mt) diff --git a/middleware/cache/handler.go b/middleware/cache/handler.go index b891d7278..045c8ab1d 100644 --- a/middleware/cache/handler.go +++ b/middleware/cache/handler.go @@ -2,6 +2,7 @@ package cache import ( "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/request" "github.com/miekg/dns" "github.com/prometheus/client_golang/prometheus" @@ -10,7 +11,7 @@ import ( // ServeDNS implements the middleware.Handler interface. func (c Cache) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { - state := middleware.State{W: w, Req: r} + state := request.Request{W: w, Req: r} qname := state.Name() qtype := state.QType() diff --git a/middleware/chaos/chaos.go b/middleware/chaos/chaos.go index 506298de4..8b6df3e84 100644 --- a/middleware/chaos/chaos.go +++ b/middleware/chaos/chaos.go @@ -4,6 +4,7 @@ import ( "os" "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/request" "github.com/miekg/dns" "golang.org/x/net/context" @@ -18,7 +19,7 @@ type Chaos struct { } func (c Chaos) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { - state := middleware.State{W: w, Req: r} + state := request.Request{W: w, Req: r} if state.QClass() != dns.ClassCHAOS || state.QType() != dns.TypeTXT { return c.Next.ServeDNS(ctx, w, r) } diff --git a/middleware/chaos/chaos_test.go b/middleware/chaos/chaos_test.go index f669eaa7c..6a3261754 100644 --- a/middleware/chaos/chaos_test.go +++ b/middleware/chaos/chaos_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/middleware/pkg/dnsrecorder" "github.com/miekg/coredns/middleware/test" "github.com/miekg/dns" @@ -58,7 +59,7 @@ func TestChaos(t *testing.T) { req.Question[0].Qclass = dns.ClassCHAOS em.Next = tc.next - rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) + rec := dnsrecorder.New(&test.ResponseWriter{}) code, err := em.ServeDNS(ctx, rec, req) if err != tc.expectedErr { @@ -68,7 +69,7 @@ func TestChaos(t *testing.T) { t.Errorf("Test %d: Expected status code %d, but got %d", i, tc.expectedCode, code) } if tc.expectedReply != "" { - answer := rec.Msg().Answer[0].(*dns.TXT).Txt[0] + answer := rec.Msg.Answer[0].(*dns.TXT).Txt[0] if answer != tc.expectedReply { t.Errorf("Test %d: Expected answer %s, but got %s", i, tc.expectedReply, answer) } diff --git a/middleware/dnssec/black_lies_test.go b/middleware/dnssec/black_lies_test.go index 951e8952e..092c36f56 100644 --- a/middleware/dnssec/black_lies_test.go +++ b/middleware/dnssec/black_lies_test.go @@ -4,8 +4,8 @@ import ( "testing" "time" - "github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware/test" + "github.com/miekg/coredns/request" "github.com/miekg/dns" ) @@ -16,7 +16,7 @@ func TestZoneSigningBlackLies(t *testing.T) { defer rm2() m := testNxdomainMsg() - state := middleware.State{Req: m} + state := request.Request{Req: m} m = d.Sign(state, "miek.nl.", time.Now().UTC()) if !section(m.Ns, 2) { t.Errorf("authority section should have 2 sig") diff --git a/middleware/dnssec/cache_test.go b/middleware/dnssec/cache_test.go index 3062f99b0..0f069b6a8 100644 --- a/middleware/dnssec/cache_test.go +++ b/middleware/dnssec/cache_test.go @@ -4,8 +4,8 @@ import ( "testing" "time" - "github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware/test" + "github.com/miekg/coredns/request" ) func TestCacheSet(t *testing.T) { @@ -20,7 +20,7 @@ func TestCacheSet(t *testing.T) { } m := testMsg() - state := middleware.State{Req: m} + state := request.Request{Req: m} k := key(m.Answer) // calculate *before* we add the sig d := New([]string{"miek.nl."}, []*DNSKEY{dnskey}, nil) m = d.Sign(state, "miek.nl.", time.Now().UTC()) diff --git a/middleware/dnssec/dnskey.go b/middleware/dnssec/dnskey.go index 9ae437c54..af345f906 100644 --- a/middleware/dnssec/dnskey.go +++ b/middleware/dnssec/dnskey.go @@ -8,7 +8,7 @@ import ( "os" "time" - "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/request" "github.com/miekg/dns" ) @@ -50,7 +50,7 @@ func ParseKeyFile(pubFile, privFile string) (*DNSKEY, error) { } // getDNSKEY returns the correct DNSKEY to the client. Signatures are added when do is true. -func (d Dnssec) getDNSKEY(state middleware.State, zone string, do bool) *dns.Msg { +func (d Dnssec) getDNSKEY(state request.Request, zone string, do bool) *dns.Msg { keys := make([]dns.RR, len(d.keys)) for i, k := range d.keys { keys[i] = dns.Copy(k.K) diff --git a/middleware/dnssec/dnssec.go b/middleware/dnssec/dnssec.go index f517bfe2c..ea914c0ee 100644 --- a/middleware/dnssec/dnssec.go +++ b/middleware/dnssec/dnssec.go @@ -4,7 +4,9 @@ import ( "time" "github.com/miekg/coredns/middleware" - "github.com/miekg/coredns/singleflight" + "github.com/miekg/coredns/middleware/pkg/response" + "github.com/miekg/coredns/middleware/pkg/singleflight" + "github.com/miekg/coredns/request" "github.com/miekg/dns" gcache "github.com/patrickmn/go-cache" @@ -28,20 +30,21 @@ func New(zones []string, keys []*DNSKEY, next middleware.Handler) Dnssec { } } -// Sign signs the message m. it takes care of negative or nodata responses. It +// Sign signs the message in state. it takes care of negative or nodata responses. It // uses NSEC black lies for authenticated denial of existence. Signatures // creates will be cached for a short while. By default we sign for 8 days, // starting 3 hours ago. -func (d Dnssec) Sign(state middleware.State, zone string, now time.Time) *dns.Msg { +func (d Dnssec) Sign(state request.Request, zone string, now time.Time) *dns.Msg { req := state.Req - mt, _ := middleware.Classify(req) // TODO(miek): need opt record here? - if mt == middleware.Delegation { + + mt, _ := response.Classify(req) // TODO(miek): need opt record here? + if mt == response.Delegation { return req } incep, expir := incepExpir(now) - if mt == middleware.NameError { + if mt == response.NameError { if req.Ns[0].Header().Rrtype != dns.TypeSOA || len(req.Ns) > 1 { return req } diff --git a/middleware/dnssec/dnssec_test.go b/middleware/dnssec/dnssec_test.go index 10f731325..48f13e935 100644 --- a/middleware/dnssec/dnssec_test.go +++ b/middleware/dnssec/dnssec_test.go @@ -4,8 +4,8 @@ import ( "testing" "time" - "github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware/test" + "github.com/miekg/coredns/request" "github.com/miekg/dns" ) @@ -16,7 +16,7 @@ func TestZoneSigning(t *testing.T) { defer rm2() m := testMsg() - state := middleware.State{Req: m} + state := request.Request{Req: m} m = d.Sign(state, "miek.nl.", time.Now().UTC()) if !section(m.Answer, 1) { @@ -44,7 +44,7 @@ func TestZoneSigningDouble(t *testing.T) { d.keys = append(d.keys, key1) m := testMsg() - state := middleware.State{Req: m} + state := request.Request{Req: m} m = d.Sign(state, "miek.nl.", time.Now().UTC()) if !section(m.Answer, 2) { t.Errorf("answer section should have 1 sig") @@ -68,7 +68,7 @@ func TestSigningDifferentZone(t *testing.T) { } m := testMsgEx() - state := middleware.State{Req: m} + state := request.Request{Req: m} d := New([]string{"example.org."}, []*DNSKEY{key}, nil) m = d.Sign(state, "example.org.", time.Now().UTC()) if !section(m.Answer, 1) { @@ -86,7 +86,7 @@ func TestSigningCname(t *testing.T) { defer rm2() m := testMsgCname() - state := middleware.State{Req: m} + state := request.Request{Req: m} m = d.Sign(state, "miek.nl.", time.Now().UTC()) if !section(m.Answer, 1) { t.Errorf("answer section should have 1 sig") @@ -100,7 +100,7 @@ func TestZoneSigningDelegation(t *testing.T) { defer rm2() m := testDelegationMsg() - state := middleware.State{Req: m} + state := request.Request{Req: m} m = d.Sign(state, "miek.nl.", time.Now().UTC()) if !section(m.Ns, 0) { t.Errorf("authority section should have 0 sig") diff --git a/middleware/dnssec/handler.go b/middleware/dnssec/handler.go index 5daf3d322..609150028 100644 --- a/middleware/dnssec/handler.go +++ b/middleware/dnssec/handler.go @@ -2,6 +2,7 @@ package dnssec import ( "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/request" "github.com/miekg/dns" "github.com/prometheus/client_golang/prometheus" @@ -10,7 +11,7 @@ import ( // ServeDNS implements the middleware.Handler interface. func (d Dnssec) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { - state := middleware.State{W: w, Req: r} + state := request.Request{W: w, Req: r} do := state.Do() qname := state.Name() diff --git a/middleware/dnssec/handler_test.go b/middleware/dnssec/handler_test.go index f7cb7e680..a490c3744 100644 --- a/middleware/dnssec/handler_test.go +++ b/middleware/dnssec/handler_test.go @@ -5,8 +5,8 @@ import ( "strings" "testing" - "github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware/file" + "github.com/miekg/coredns/middleware/pkg/dnsrecorder" "github.com/miekg/coredns/middleware/test" "github.com/miekg/dns" @@ -83,14 +83,14 @@ func TestLookupZone(t *testing.T) { for _, tc := range dnsTestCases { m := tc.Msg() - rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) + rec := dnsrecorder.New(&test.ResponseWriter{}) _, err := dh.ServeDNS(ctx, rec, m) if err != nil { t.Errorf("expected no error, got %v\n", err) return } - resp := rec.Msg() + resp := rec.Msg sort.Sort(test.RRSet(resp.Answer)) sort.Sort(test.RRSet(resp.Ns)) sort.Sort(test.RRSet(resp.Extra)) @@ -121,14 +121,14 @@ func TestLookupDNSKEY(t *testing.T) { for _, tc := range dnssecTestCases { m := tc.Msg() - rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) + rec := dnsrecorder.New(&test.ResponseWriter{}) _, err := dh.ServeDNS(ctx, rec, m) if err != nil { t.Errorf("expected no error, got %v\n", err) return } - resp := rec.Msg() + resp := rec.Msg if !resp.Authoritative { t.Errorf("Authoritative Answer should be true, got false") } diff --git a/middleware/dnssec/responsewriter.go b/middleware/dnssec/responsewriter.go index 2a7cbb972..0032fa7ba 100644 --- a/middleware/dnssec/responsewriter.go +++ b/middleware/dnssec/responsewriter.go @@ -5,6 +5,8 @@ import ( "time" "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/request" + "github.com/miekg/dns" ) @@ -20,7 +22,7 @@ func NewDnssecResponseWriter(w dns.ResponseWriter, d Dnssec) *DnssecResponseWrit func (d *DnssecResponseWriter) WriteMsg(res *dns.Msg) error { // By definition we should sign anything that comes back, we should still figure out for // which zone it should be. - state := middleware.State{W: d.ResponseWriter, Req: res} + state := request.Request{W: d.ResponseWriter, Req: res} qname := state.Name() zone := middleware.Zones(d.d.zones).Matches(qname) diff --git a/middleware/errors/errors.go b/middleware/errors/errors.go index 8936b8945..ef178d1a3 100644 --- a/middleware/errors/errors.go +++ b/middleware/errors/errors.go @@ -9,6 +9,8 @@ import ( "time" "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/middleware/pkg/roller" + "github.com/miekg/coredns/request" "github.com/miekg/dns" "golang.org/x/net/context" @@ -19,7 +21,7 @@ type ErrorHandler struct { Next middleware.Handler LogFile string Log *log.Logger - LogRoller *middleware.LogRoller + LogRoller *roller.LogRoller Debug bool // if true, errors are written out to client rather than to a log } @@ -29,7 +31,7 @@ func (h ErrorHandler) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns rcode, err := h.Next.ServeDNS(ctx, w, r) if err != nil { - state := middleware.State{W: w, Req: r} + state := request.Request{W: w, Req: r} errMsg := fmt.Sprintf("%s [ERROR %d %s %s] %v", time.Now().Format(timeFormat), rcode, state.Name(), state.Type(), err) if h.Debug { @@ -53,7 +55,7 @@ func (h ErrorHandler) recovery(ctx context.Context, w dns.ResponseWriter, r *dns return } - state := middleware.State{W: w, Req: r} + state := request.Request{W: w, Req: r} // Obtain source of panic // From: https://gist.github.com/swdunlop/9629168 var name, file string // function name, file name diff --git a/middleware/errors/errors_test.go b/middleware/errors/errors_test.go index 6885aae0f..4c8f56fa2 100644 --- a/middleware/errors/errors_test.go +++ b/middleware/errors/errors_test.go @@ -9,6 +9,7 @@ import ( "testing" "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/middleware/pkg/dnsrecorder" "github.com/miekg/coredns/middleware/test" "github.com/miekg/dns" @@ -47,7 +48,7 @@ func TestErrors(t *testing.T) { for i, tc := range tests { em.Next = tc.next buf.Reset() - rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) + rec := dnsrecorder.New(&test.ResponseWriter{}) code, err := em.ServeDNS(ctx, rec, req) if err != tc.expectedErr { @@ -78,7 +79,7 @@ func TestVisibleErrorWithPanic(t *testing.T) { req := new(dns.Msg) req.SetQuestion("example.org.", dns.TypeA) - rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) + rec := dnsrecorder.New(&test.ResponseWriter{}) code, err := eh.ServeDNS(ctx, rec, req) if code != 0 { diff --git a/middleware/errors/setup.go b/middleware/errors/setup.go index e1c77373d..5c7c1016c 100644 --- a/middleware/errors/setup.go +++ b/middleware/errors/setup.go @@ -6,7 +6,7 @@ import ( "os" "github.com/miekg/coredns/core/dnsserver" - "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/middleware/pkg/roller" "github.com/hashicorp/go-syslog" "github.com/mholt/caddy" @@ -93,7 +93,7 @@ func errorsParse(c *caddy.Controller) (ErrorHandler, error) { if c.NextArg() { if c.Val() == "{" { c.IncrNest() - logRoller, err := middleware.ParseRoller(c) + logRoller, err := roller.Parse(c) if err != nil { return hadBlock, err } diff --git a/middleware/errors/setup_test.go b/middleware/errors/setup_test.go index 6e5a85d08..c1dbf7267 100644 --- a/middleware/errors/setup_test.go +++ b/middleware/errors/setup_test.go @@ -4,7 +4,7 @@ import ( "testing" "github.com/mholt/caddy" - "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/middleware/pkg/roller" ) func TestErrorsParse(t *testing.T) { @@ -29,7 +29,7 @@ func TestErrorsParse(t *testing.T) { }}, {`errors { log errors.txt { size 2 age 10 keep 3 } }`, false, ErrorHandler{ LogFile: "errors.txt", - LogRoller: &middleware.LogRoller{ + LogRoller: &roller.LogRoller{ MaxSize: 2, MaxAge: 10, MaxBackups: 3, @@ -43,7 +43,7 @@ func TestErrorsParse(t *testing.T) { } }`, false, ErrorHandler{ LogFile: "errors.txt", - LogRoller: &middleware.LogRoller{ + LogRoller: &roller.LogRoller{ MaxSize: 3, MaxAge: 11, MaxBackups: 5, diff --git a/middleware/etcd/cname_test.go b/middleware/etcd/cname_test.go index ee341b7b6..fdd0f50e5 100644 --- a/middleware/etcd/cname_test.go +++ b/middleware/etcd/cname_test.go @@ -7,8 +7,8 @@ package etcd import ( "testing" - "github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware/etcd/msg" + "github.com/miekg/coredns/middleware/pkg/dnsrecorder" "github.com/miekg/coredns/middleware/test" "github.com/miekg/dns" @@ -25,14 +25,14 @@ func TestCnameLookup(t *testing.T) { for _, tc := range dnsTestCasesCname { m := tc.Msg() - rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) + rec := dnsrecorder.New(&test.ResponseWriter{}) _, err := etc.ServeDNS(ctxt, rec, m) if err != nil { t.Errorf("expected no error, got %v\n", err) return } - resp := rec.Msg() + resp := rec.Msg if !test.Header(t, tc, resp) { t.Logf("%v\n", resp) continue diff --git a/middleware/etcd/debug_test.go b/middleware/etcd/debug_test.go index 82de9fe1f..930ceb8ce 100644 --- a/middleware/etcd/debug_test.go +++ b/middleware/etcd/debug_test.go @@ -7,8 +7,8 @@ import ( "strings" "testing" - "github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware/etcd/msg" + "github.com/miekg/coredns/middleware/pkg/dnsrecorder" "github.com/miekg/coredns/middleware/test" "github.com/miekg/dns" @@ -41,14 +41,14 @@ func TestDebugLookup(t *testing.T) { for _, tc := range dnsTestCasesDebug { m := tc.Msg() - rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) + rec := dnsrecorder.New(&test.ResponseWriter{}) _, err := etc.ServeDNS(ctxt, rec, m) if err != nil { t.Errorf("expected no error, got %v\n", err) continue } - resp := rec.Msg() + resp := rec.Msg sort.Sort(test.RRSet(resp.Answer)) sort.Sort(test.RRSet(resp.Ns)) sort.Sort(test.RRSet(resp.Extra)) @@ -79,14 +79,14 @@ func TestDebugLookupFalse(t *testing.T) { for _, tc := range dnsTestCasesDebugFalse { m := tc.Msg() - rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) + rec := dnsrecorder.New(&test.ResponseWriter{}) _, err := etc.ServeDNS(ctxt, rec, m) if err != nil { t.Errorf("expected no error, got %v\n", err) continue } - resp := rec.Msg() + resp := rec.Msg sort.Sort(test.RRSet(resp.Answer)) sort.Sort(test.RRSet(resp.Ns)) sort.Sort(test.RRSet(resp.Extra)) diff --git a/middleware/etcd/etcd.go b/middleware/etcd/etcd.go index 87a82e8cb..4b5b424f4 100644 --- a/middleware/etcd/etcd.go +++ b/middleware/etcd/etcd.go @@ -9,8 +9,8 @@ import ( "github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware/etcd/msg" + "github.com/miekg/coredns/middleware/pkg/singleflight" "github.com/miekg/coredns/middleware/proxy" - "github.com/miekg/coredns/singleflight" etcdc "github.com/coreos/etcd/client" "golang.org/x/net/context" diff --git a/middleware/etcd/group_test.go b/middleware/etcd/group_test.go index 7a2808d45..abf777982 100644 --- a/middleware/etcd/group_test.go +++ b/middleware/etcd/group_test.go @@ -6,8 +6,8 @@ import ( "sort" "testing" - "github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware/etcd/msg" + "github.com/miekg/coredns/middleware/pkg/dnsrecorder" "github.com/miekg/coredns/middleware/test" "github.com/miekg/dns" @@ -23,14 +23,14 @@ func TestGroupLookup(t *testing.T) { for _, tc := range dnsTestCasesGroup { m := tc.Msg() - rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) + rec := dnsrecorder.New(&test.ResponseWriter{}) _, err := etc.ServeDNS(ctxt, rec, m) if err != nil { t.Errorf("expected no error, got %v\n", err) continue } - resp := rec.Msg() + resp := rec.Msg sort.Sort(test.RRSet(resp.Answer)) sort.Sort(test.RRSet(resp.Ns)) sort.Sort(test.RRSet(resp.Extra)) diff --git a/middleware/etcd/handler.go b/middleware/etcd/handler.go index 132dba370..40f9523f0 100644 --- a/middleware/etcd/handler.go +++ b/middleware/etcd/handler.go @@ -5,6 +5,7 @@ import ( "github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware/etcd/msg" + "github.com/miekg/coredns/request" "github.com/miekg/dns" "golang.org/x/net/context" @@ -12,7 +13,7 @@ import ( func (e *Etcd) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { opt := Options{} - state := middleware.State{W: w, Req: r} + state := request.Request{W: w, Req: r} if state.QClass() != dns.ClassINET { return dns.RcodeServerFailure, fmt.Errorf("can only deal with ClassINET") } @@ -115,7 +116,7 @@ func (e *Etcd) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) ( } // Err write an error response to the client. -func (e *Etcd) Err(zone string, rcode int, state middleware.State, debug []msg.Service, err error, opt Options) (int, error) { +func (e *Etcd) Err(zone string, rcode int, state request.Request, debug []msg.Service, err error, opt Options) (int, error) { m := new(dns.Msg) m.SetRcode(state.Req, rcode) m.Authoritative, m.RecursionAvailable, m.Compress = true, true, true diff --git a/middleware/etcd/lookup.go b/middleware/etcd/lookup.go index 488ca4f26..b13c257e6 100644 --- a/middleware/etcd/lookup.go +++ b/middleware/etcd/lookup.go @@ -8,6 +8,8 @@ import ( "github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware/etcd/msg" + "github.com/miekg/coredns/middleware/pkg/dnsutil" + "github.com/miekg/coredns/request" "github.com/miekg/dns" ) @@ -16,7 +18,7 @@ type Options struct { Debug string } -func (e Etcd) records(state middleware.State, exact bool, opt Options) (services, debug []msg.Service, err error) { +func (e Etcd) records(state request.Request, exact bool, opt Options) (services, debug []msg.Service, err error) { services, err = e.Records(state.Name(), exact) if err != nil { return @@ -28,7 +30,7 @@ func (e Etcd) records(state middleware.State, exact bool, opt Options) (services return } -func (e Etcd) A(zone string, state middleware.State, previousRecords []dns.RR, opt Options) (records []dns.RR, debug []msg.Service, err error) { +func (e Etcd) A(zone string, state request.Request, previousRecords []dns.RR, opt Options) (records []dns.RR, debug []msg.Service, err error) { services, debug, err := e.records(state, false, opt) if err != nil { return nil, debug, err @@ -49,11 +51,11 @@ func (e Etcd) A(zone string, state middleware.State, previousRecords []dns.RR, o // don't add it, and just continue continue } - if isDuplicateCNAME(newRecord, previousRecords) { + if dnsutil.DuplicateCNAME(newRecord, previousRecords) { continue } - state1 := copyState(state, serv.Host, state.QType()) + state1 := state.NewWithQuestion(serv.Host, state.QType()) nextRecords, nextDebug, err := e.A(zone, state1, append(previousRecords, newRecord), opt) if err == nil { @@ -90,7 +92,7 @@ func (e Etcd) A(zone string, state middleware.State, previousRecords []dns.RR, o return records, debug, nil } -func (e Etcd) AAAA(zone string, state middleware.State, previousRecords []dns.RR, opt Options) (records []dns.RR, debug []msg.Service, err error) { +func (e Etcd) AAAA(zone string, state request.Request, previousRecords []dns.RR, opt Options) (records []dns.RR, debug []msg.Service, err error) { services, debug, err := e.records(state, false, opt) if err != nil { return nil, debug, err @@ -111,11 +113,11 @@ func (e Etcd) AAAA(zone string, state middleware.State, previousRecords []dns.RR // don't add it, and just continue continue } - if isDuplicateCNAME(newRecord, previousRecords) { + if dnsutil.DuplicateCNAME(newRecord, previousRecords) { continue } - state1 := copyState(state, serv.Host, state.QType()) + state1 := state.NewWithQuestion(serv.Host, state.QType()) nextRecords, nextDebug, err := e.AAAA(zone, state1, append(previousRecords, newRecord), opt) if err == nil { @@ -155,7 +157,7 @@ func (e Etcd) AAAA(zone string, state middleware.State, previousRecords []dns.RR // SRV returns SRV records from etcd. // If the Target is not a name but an IP address, a name is created on the fly. -func (e Etcd) SRV(zone string, state middleware.State, opt Options) (records, extra []dns.RR, debug []msg.Service, err error) { +func (e Etcd) SRV(zone string, state request.Request, opt Options) (records, extra []dns.RR, debug []msg.Service, err error) { services, debug, err := e.records(state, false, opt) if err != nil { return nil, nil, nil, err @@ -220,7 +222,7 @@ func (e Etcd) SRV(zone string, state middleware.State, opt Options) (records, ex } // Internal name, we should have some info on them, either v4 or v6 // Clients expect a complete answer, because we are a recursor in their view. - state1 := copyState(state, srv.Target, dns.TypeA) + state1 := state.NewWithQuestion(srv.Target, dns.TypeA) addr, debugAddr, e1 := e.A(zone, state1, nil, opt) if e1 == nil { extra = append(extra, addr...) @@ -246,7 +248,7 @@ func (e Etcd) SRV(zone string, state middleware.State, opt Options) (records, ex // MX returns MX records from etcd. // If the Target is not a name but an IP address, a name is created on the fly. -func (e Etcd) MX(zone string, state middleware.State, opt Options) (records, extra []dns.RR, debug []msg.Service, err error) { +func (e Etcd) MX(zone string, state request.Request, opt Options) (records, extra []dns.RR, debug []msg.Service, err error) { services, debug, err := e.records(state, false, opt) if err != nil { return nil, nil, debug, err @@ -291,7 +293,7 @@ func (e Etcd) MX(zone string, state middleware.State, opt Options) (records, ext break } // Internal name - state1 := copyState(state, mx.Mx, dns.TypeA) + state1 := state.NewWithQuestion(mx.Mx, dns.TypeA) addr, debugAddr, e1 := e.A(zone, state1, nil, opt) if e1 == nil { extra = append(extra, addr...) @@ -311,7 +313,7 @@ func (e Etcd) MX(zone string, state middleware.State, opt Options) (records, ext return records, extra, debug, nil } -func (e Etcd) CNAME(zone string, state middleware.State, opt Options) (records []dns.RR, debug []msg.Service, err error) { +func (e Etcd) CNAME(zone string, state request.Request, opt Options) (records []dns.RR, debug []msg.Service, err error) { services, debug, err := e.records(state, true, opt) if err != nil { return nil, debug, err @@ -327,7 +329,7 @@ func (e Etcd) CNAME(zone string, state middleware.State, opt Options) (records [ } // PTR returns the PTR records, only services that have a domain name as host are included. -func (e Etcd) PTR(zone string, state middleware.State, opt Options) (records []dns.RR, debug []msg.Service, err error) { +func (e Etcd) PTR(zone string, state request.Request, opt Options) (records []dns.RR, debug []msg.Service, err error) { services, debug, err := e.records(state, true, opt) if err != nil { return nil, debug, err @@ -341,7 +343,7 @@ func (e Etcd) PTR(zone string, state middleware.State, opt Options) (records []d return records, debug, nil } -func (e Etcd) TXT(zone string, state middleware.State, opt Options) (records []dns.RR, debug []msg.Service, err error) { +func (e Etcd) TXT(zone string, state request.Request, opt Options) (records []dns.RR, debug []msg.Service, err error) { services, debug, err := e.records(state, false, opt) if err != nil { return nil, debug, err @@ -356,7 +358,7 @@ func (e Etcd) TXT(zone string, state middleware.State, opt Options) (records []d return records, debug, nil } -func (e Etcd) NS(zone string, state middleware.State, opt Options) (records, extra []dns.RR, debug []msg.Service, err error) { +func (e Etcd) NS(zone string, state request.Request, opt Options) (records, extra []dns.RR, debug []msg.Service, err error) { // NS record for this zone live in a special place, ns.dns.. Fake our lookup. // only a tad bit fishy... old := state.QName() @@ -389,7 +391,7 @@ func (e Etcd) NS(zone string, state middleware.State, opt Options) (records, ext } // SOA Record returns a SOA record. -func (e Etcd) SOA(zone string, state middleware.State, opt Options) ([]dns.RR, []msg.Service, error) { +func (e Etcd) SOA(zone string, state request.Request, opt Options) ([]dns.RR, []msg.Service, error) { header := dns.RR_Header{Name: zone, Rrtype: dns.TypeSOA, Ttl: 300, Class: dns.ClassINET} soa := &dns.SOA{Hdr: header, @@ -404,21 +406,3 @@ func (e Etcd) SOA(zone string, state middleware.State, opt Options) ([]dns.RR, [ // TODO(miek): fake some msg.Service here when returning. return []dns.RR{soa}, nil, nil } - -func isDuplicateCNAME(r *dns.CNAME, records []dns.RR) bool { - for _, rec := range records { - if v, ok := rec.(*dns.CNAME); ok { - if v.Target == r.Target { - return true - } - } - } - return false -} - -// TODO(miek): Move to middleware? -func copyState(state middleware.State, target string, typ uint16) middleware.State { - state1 := middleware.State{W: state.W, Req: state.Req.Copy()} - state1.Req.Question[0] = dns.Question{Name: dns.Fqdn(target), Qclass: dns.ClassINET, Qtype: typ} - return state1 -} diff --git a/middleware/etcd/multi_test.go b/middleware/etcd/multi_test.go index f4b59f50b..d4e3493a7 100644 --- a/middleware/etcd/multi_test.go +++ b/middleware/etcd/multi_test.go @@ -6,8 +6,8 @@ import ( "sort" "testing" - "github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware/etcd/msg" + "github.com/miekg/coredns/middleware/pkg/dnsrecorder" "github.com/miekg/coredns/middleware/test" "github.com/miekg/dns" @@ -25,14 +25,14 @@ func TestMultiLookup(t *testing.T) { for _, tc := range dnsTestCasesMulti { m := tc.Msg() - rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) + rec := dnsrecorder.New(&test.ResponseWriter{}) _, err := etc.ServeDNS(ctxt, rec, m) if err != nil { t.Errorf("expected no error, got %v\n", err) return } - resp := rec.Msg() + resp := rec.Msg sort.Sort(test.RRSet(resp.Answer)) sort.Sort(test.RRSet(resp.Ns)) sort.Sort(test.RRSet(resp.Extra)) diff --git a/middleware/etcd/other_test.go b/middleware/etcd/other_test.go index ff37d27d2..2ed622fe0 100644 --- a/middleware/etcd/other_test.go +++ b/middleware/etcd/other_test.go @@ -10,8 +10,8 @@ import ( "strings" "testing" - "github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware/etcd/msg" + "github.com/miekg/coredns/middleware/pkg/dnsrecorder" "github.com/miekg/coredns/middleware/test" "github.com/miekg/dns" @@ -27,14 +27,14 @@ func TestOtherLookup(t *testing.T) { for _, tc := range dnsTestCasesOther { m := tc.Msg() - rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) + rec := dnsrecorder.New(&test.ResponseWriter{}) _, err := etc.ServeDNS(ctxt, rec, m) if err != nil { t.Errorf("expected no error, got %v\n", err) continue } - resp := rec.Msg() + resp := rec.Msg sort.Sort(test.RRSet(resp.Answer)) sort.Sort(test.RRSet(resp.Ns)) sort.Sort(test.RRSet(resp.Extra)) diff --git a/middleware/etcd/proxy_lookup_test.go b/middleware/etcd/proxy_lookup_test.go index 5e0999fb0..8b4697e25 100644 --- a/middleware/etcd/proxy_lookup_test.go +++ b/middleware/etcd/proxy_lookup_test.go @@ -6,8 +6,8 @@ import ( "sort" "testing" - "github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware/etcd/msg" + "github.com/miekg/coredns/middleware/pkg/dnsrecorder" "github.com/miekg/coredns/middleware/proxy" "github.com/miekg/coredns/middleware/test" @@ -27,14 +27,14 @@ func TestProxyLookupFailDebug(t *testing.T) { for _, tc := range dnsTestCasesProxy { m := tc.Msg() - rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) + rec := dnsrecorder.New(&test.ResponseWriter{}) _, err := etc.ServeDNS(ctxt, rec, m) if err != nil { t.Errorf("expected no error, got %v\n", err) continue } - resp := rec.Msg() + resp := rec.Msg sort.Sort(test.RRSet(resp.Answer)) sort.Sort(test.RRSet(resp.Ns)) sort.Sort(test.RRSet(resp.Extra)) diff --git a/middleware/etcd/setup.go b/middleware/etcd/setup.go index dc1dddb0e..58ca8286f 100644 --- a/middleware/etcd/setup.go +++ b/middleware/etcd/setup.go @@ -10,8 +10,8 @@ import ( "github.com/miekg/coredns/core/dnsserver" "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/middleware/pkg/singleflight" "github.com/miekg/coredns/middleware/proxy" - "github.com/miekg/coredns/singleflight" etcdc "github.com/coreos/etcd/client" "github.com/mholt/caddy" @@ -70,7 +70,7 @@ func etcdParse(c *caddy.Controller) (*Etcd, bool, error) { etc.Zones = make([]string, len(c.ServerBlockKeys)) copy(etc.Zones, c.ServerBlockKeys) } - middleware.Zones(etc.Zones).FullyQualify() + middleware.Zones(etc.Zones).Normalize() if c.NextBlock() { // TODO(miek): 2 switches? switch c.Val() { diff --git a/middleware/etcd/setup_test.go b/middleware/etcd/setup_test.go index b522345d2..39ad3d3a1 100644 --- a/middleware/etcd/setup_test.go +++ b/middleware/etcd/setup_test.go @@ -8,11 +8,11 @@ import ( "testing" "time" - "github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware/etcd/msg" + "github.com/miekg/coredns/middleware/pkg/dnsrecorder" + "github.com/miekg/coredns/middleware/pkg/singleflight" "github.com/miekg/coredns/middleware/proxy" "github.com/miekg/coredns/middleware/test" - "github.com/miekg/coredns/singleflight" etcdc "github.com/coreos/etcd/client" "github.com/miekg/dns" @@ -65,14 +65,14 @@ func TestLookup(t *testing.T) { for _, tc := range dnsTestCases { m := tc.Msg() - rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) + rec := dnsrecorder.New(&test.ResponseWriter{}) _, err := etc.ServeDNS(ctxt, rec, m) if err != nil { t.Errorf("expected no error, got: %v for %s %s\n", err, m.Question[0].Name, dns.Type(m.Question[0].Qtype)) return } - resp := rec.Msg() + resp := rec.Msg sort.Sort(test.RRSet(resp.Answer)) sort.Sort(test.RRSet(resp.Ns)) sort.Sort(test.RRSet(resp.Extra)) diff --git a/middleware/etcd/stub_handler.go b/middleware/etcd/stub_handler.go index 9d8778219..10a3d2198 100644 --- a/middleware/etcd/stub_handler.go +++ b/middleware/etcd/stub_handler.go @@ -4,7 +4,7 @@ import ( "errors" "log" - "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/request" "github.com/miekg/dns" "golang.org/x/net/context" @@ -27,7 +27,7 @@ func (s Stub) ServeDNS(ctx context.Context, w dns.ResponseWriter, req *dns.Msg) return dns.RcodeServerFailure, nil } - state := middleware.State{W: w, Req: req} + state := request.Request{W: w, Req: req} m, e := proxy.Forward(state) if e != nil { return dns.RcodeServerFailure, e diff --git a/middleware/etcd/stub_test.go b/middleware/etcd/stub_test.go index b5a101dad..72a8fb7a0 100644 --- a/middleware/etcd/stub_test.go +++ b/middleware/etcd/stub_test.go @@ -8,8 +8,8 @@ import ( "strconv" "testing" - "github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware/etcd/msg" + "github.com/miekg/coredns/middleware/pkg/dnsrecorder" "github.com/miekg/coredns/middleware/test" "github.com/miekg/dns" @@ -53,7 +53,7 @@ func TestStubLookup(t *testing.T) { for _, tc := range dnsTestCasesStub { m := tc.Msg() - rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) + rec := dnsrecorder.New(&test.ResponseWriter{}) _, err := etc.ServeDNS(ctxt, rec, m) if err != nil && m.Question[0].Name == "example.org." { // This is OK, we expect this backend to *not* work. @@ -62,12 +62,11 @@ func TestStubLookup(t *testing.T) { if err != nil { t.Errorf("expected no error, got %v for %s\n", err, m.Question[0].Name) } - resp := rec.Msg() + resp := rec.Msg if resp == nil { // etcd not running? continue } - sort.Sort(test.RRSet(resp.Answer)) sort.Sort(test.RRSet(resp.Ns)) sort.Sort(test.RRSet(resp.Extra)) diff --git a/middleware/exchange.go b/middleware/exchange.go deleted file mode 100644 index 783d06e26..000000000 --- a/middleware/exchange.go +++ /dev/null @@ -1,8 +0,0 @@ -package middleware - -import "github.com/miekg/dns" - -func Exchange(c *dns.Client, m *dns.Msg, server string) (*dns.Msg, error) { - r, _, err := c.Exchange(m, server) - return r, err -} diff --git a/middleware/file/delegation_test.go b/middleware/file/delegation_test.go index 727febade..30a6e430c 100644 --- a/middleware/file/delegation_test.go +++ b/middleware/file/delegation_test.go @@ -5,7 +5,7 @@ import ( "strings" "testing" - "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/middleware/pkg/dnsrecorder" "github.com/miekg/coredns/middleware/test" "github.com/miekg/dns" @@ -57,14 +57,14 @@ func TestLookupDelegation(t *testing.T) { for _, tc := range delegationTestCases { m := tc.Msg() - rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) + rec := dnsrecorder.New(&test.ResponseWriter{}) _, err := fm.ServeDNS(ctx, rec, m) if err != nil { t.Errorf("expected no error, got %v\n", err) return } - resp := rec.Msg() + resp := rec.Msg sort.Sort(test.RRSet(resp.Answer)) sort.Sort(test.RRSet(resp.Ns)) sort.Sort(test.RRSet(resp.Extra)) diff --git a/middleware/file/dnssec_test.go b/middleware/file/dnssec_test.go index 2d76447b3..434ebed8f 100644 --- a/middleware/file/dnssec_test.go +++ b/middleware/file/dnssec_test.go @@ -5,7 +5,7 @@ import ( "strings" "testing" - "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/middleware/pkg/dnsrecorder" "github.com/miekg/coredns/middleware/test" "github.com/miekg/dns" @@ -116,14 +116,14 @@ func TestLookupDNSSEC(t *testing.T) { for _, tc := range dnssecTestCases { m := tc.Msg() - rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) + rec := dnsrecorder.New(&test.ResponseWriter{}) _, err := fm.ServeDNS(ctx, rec, m) if err != nil { t.Errorf("expected no error, got %v\n", err) return } - resp := rec.Msg() + resp := rec.Msg sort.Sort(test.RRSet(resp.Answer)) sort.Sort(test.RRSet(resp.Ns)) sort.Sort(test.RRSet(resp.Extra)) @@ -154,7 +154,7 @@ func BenchmarkLookupDNSSEC(b *testing.B) { fm := File{Next: test.ErrorHandler(), Zones: Zones{Z: map[string]*Zone{testzone: zone}, Names: []string{testzone}}} ctx := context.TODO() - rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) + rec := dnsrecorder.New(&test.ResponseWriter{}) tc := test.Case{ Qname: "b.miek.nl.", Qtype: dns.TypeA, Do: true, diff --git a/middleware/file/ent_test.go b/middleware/file/ent_test.go index 735ec67fe..324c6cfb6 100644 --- a/middleware/file/ent_test.go +++ b/middleware/file/ent_test.go @@ -5,7 +5,7 @@ import ( "strings" "testing" - "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/middleware/pkg/dnsrecorder" "github.com/miekg/coredns/middleware/test" "github.com/miekg/dns" @@ -43,14 +43,14 @@ func TestLookupENT(t *testing.T) { for _, tc := range entTestCases { m := tc.Msg() - rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) + rec := dnsrecorder.New(&test.ResponseWriter{}) _, err := fm.ServeDNS(ctx, rec, m) if err != nil { t.Errorf("expected no error, got %v\n", err) return } - resp := rec.Msg() + resp := rec.Msg sort.Sort(test.RRSet(resp.Answer)) sort.Sort(test.RRSet(resp.Ns)) sort.Sort(test.RRSet(resp.Extra)) diff --git a/middleware/file/file.go b/middleware/file/file.go index a99a64d6f..d2ced17e4 100644 --- a/middleware/file/file.go +++ b/middleware/file/file.go @@ -6,6 +6,7 @@ import ( "log" "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/request" "github.com/miekg/dns" "golang.org/x/net/context" @@ -24,7 +25,7 @@ type ( ) func (f File) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { - state := middleware.State{W: w, Req: r} + state := request.Request{W: w, Req: r} if state.QClass() != dns.ClassINET { return dns.RcodeServerFailure, errors.New("can only deal with ClassINET") diff --git a/middleware/file/lookup_test.go b/middleware/file/lookup_test.go index 90f525e3c..d8efd6ea6 100644 --- a/middleware/file/lookup_test.go +++ b/middleware/file/lookup_test.go @@ -5,7 +5,7 @@ import ( "strings" "testing" - "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/middleware/pkg/dnsrecorder" "github.com/miekg/coredns/middleware/test" "github.com/miekg/dns" @@ -87,14 +87,14 @@ func TestLookup(t *testing.T) { for _, tc := range dnsTestCases { m := tc.Msg() - rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) + rec := dnsrecorder.New(&test.ResponseWriter{}) _, err := fm.ServeDNS(ctx, rec, m) if err != nil { t.Errorf("expected no error, got %v\n", err) return } - resp := rec.Msg() + resp := rec.Msg sort.Sort(test.RRSet(resp.Answer)) sort.Sort(test.RRSet(resp.Ns)) sort.Sort(test.RRSet(resp.Extra)) @@ -122,7 +122,7 @@ func TestLookupNil(t *testing.T) { ctx := context.TODO() m := dnsTestCases[0].Msg() - rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) + rec := dnsrecorder.New(&test.ResponseWriter{}) fm.ServeDNS(ctx, rec, m) } @@ -134,7 +134,7 @@ func BenchmarkLookup(b *testing.B) { fm := File{Next: test.ErrorHandler(), Zones: Zones{Z: map[string]*Zone{testzone: zone}, Names: []string{testzone}}} ctx := context.TODO() - rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) + rec := dnsrecorder.New(&test.ResponseWriter{}) tc := test.Case{ Qname: "www.miek.nl.", Qtype: dns.TypeA, diff --git a/middleware/file/notify.go b/middleware/file/notify.go index 8a2581b84..3c6095a3b 100644 --- a/middleware/file/notify.go +++ b/middleware/file/notify.go @@ -5,6 +5,7 @@ import ( "log" "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/request" "github.com/miekg/dns" ) @@ -12,7 +13,7 @@ import ( // isNotify checks if state is a notify message and if so, will *also* check if it // is from one of the configured masters. If not it will not be a valid notify // message. If the zone z is not a secondary zone the message will also be ignored. -func (z *Zone) isNotify(state middleware.State) bool { +func (z *Zone) isNotify(state request.Request) bool { if state.Req.Opcode != dns.OpcodeNotify { return false } @@ -56,7 +57,7 @@ func notify(zone string, to []string) error { func notifyAddr(c *dns.Client, m *dns.Msg, s string) error { for i := 0; i < 3; i++ { - ret, err := middleware.Exchange(c, m, s) + ret, _, err := c.Exchange(m, s) if err != nil { continue } diff --git a/middleware/file/secondary.go b/middleware/file/secondary.go index 5b1fe0cf2..9493a7fdd 100644 --- a/middleware/file/secondary.go +++ b/middleware/file/secondary.go @@ -4,8 +4,6 @@ import ( "log" "time" - "github.com/miekg/coredns/middleware" - "github.com/miekg/dns" ) @@ -75,7 +73,7 @@ func (z *Zone) shouldTransfer() (bool, error) { Transfer: for _, tr := range z.TransferFrom { Err = nil - ret, err := middleware.Exchange(c, m, tr) + ret, _, err := c.Exchange(m, tr) if err != nil || ret.Rcode != dns.RcodeSuccess { Err = err continue diff --git a/middleware/file/secondary_test.go b/middleware/file/secondary_test.go index 8eff84efb..b32c7aca7 100644 --- a/middleware/file/secondary_test.go +++ b/middleware/file/secondary_test.go @@ -4,8 +4,8 @@ import ( "fmt" "testing" - "github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware/test" + "github.com/miekg/coredns/request" "github.com/miekg/dns" ) @@ -134,7 +134,7 @@ func TestIsNotify(t *testing.T) { z := new(Zone) z.Expired = new(bool) z.origin = testZone - state := NewState(testZone, dns.TypeSOA) + state := newRequest(testZone, dns.TypeSOA) // need to set opcode state.Req.Opcode = dns.OpcodeNotify @@ -148,9 +148,9 @@ func TestIsNotify(t *testing.T) { } } -func NewState(zone string, qtype uint16) middleware.State { +func newRequest(zone string, qtype uint16) request.Request { m := new(dns.Msg) m.SetQuestion("example.com.", dns.TypeA) m.SetEdns0(4097, true) - return middleware.State{W: &test.ResponseWriter{}, Req: m} + return request.Request{W: &test.ResponseWriter{}, Req: m} } diff --git a/middleware/file/tree/elem.go b/middleware/file/tree/elem.go index 6785a6849..785d64660 100644 --- a/middleware/file/tree/elem.go +++ b/middleware/file/tree/elem.go @@ -1,9 +1,6 @@ package tree -import ( - "github.com/miekg/coredns/middleware" - "github.com/miekg/dns" -) +import "github.com/miekg/dns" type Elem struct { m map[uint16][]dns.RR @@ -91,8 +88,8 @@ func (e *Elem) Delete(rr dns.RR) (empty bool) { return } -// Less is a tree helper function that calls middleware.Less. -func Less(a *Elem, name string) int { return middleware.Less(name, a.Name()) } +// Less is a tree helper function that calls less. +func Less(a *Elem, name string) int { return less(name, a.Name()) } // Assuming the same type and name this will check if the rdata is equal as well. func equalRdata(a, b dns.RR) bool { diff --git a/middleware/canonical.go b/middleware/file/tree/less.go similarity index 91% rename from middleware/canonical.go rename to middleware/file/tree/less.go index fd30946bf..32d87b683 100644 --- a/middleware/canonical.go +++ b/middleware/file/tree/less.go @@ -1,4 +1,4 @@ -package middleware +package tree import ( "bytes" @@ -6,7 +6,7 @@ import ( "github.com/miekg/dns" ) -// Less returns <0 when a is less than b, 0 when they are equal and +// less returns <0 when a is less than b, 0 when they are equal and // >0 when a is larger than b. // The function orders names in DNSSEC canonical order: RFC 4034s section-6.1 // @@ -14,7 +14,7 @@ import ( // for a blog article on this implementation. // // The values of a and b are *not* lowercased before the comparison! -func Less(a, b string) int { +func less(a, b string) int { i := 1 aj := len(a) bj := len(b) diff --git a/middleware/canonical_test.go b/middleware/file/tree/less_test.go similarity index 96% rename from middleware/canonical_test.go rename to middleware/file/tree/less_test.go index 1a9fd0859..419b75c55 100644 --- a/middleware/canonical_test.go +++ b/middleware/file/tree/less_test.go @@ -1,4 +1,4 @@ -package middleware +package tree import ( "sort" @@ -10,7 +10,7 @@ type set []string func (p set) Len() int { return len(p) } func (p set) Swap(i, j int) { p[i], p[j] = p[j], p[i] } -func (p set) Less(i, j int) bool { d := Less(p[i], p[j]); return d <= 0 } +func (p set) Less(i, j int) bool { d := less(p[i], p[j]); return d <= 0 } func TestLess(t *testing.T) { tests := []struct { diff --git a/middleware/file/wildcard_test.go b/middleware/file/wildcard_test.go index cff1292cc..b0cc4c610 100644 --- a/middleware/file/wildcard_test.go +++ b/middleware/file/wildcard_test.go @@ -5,7 +5,7 @@ import ( "strings" "testing" - "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/middleware/pkg/dnsrecorder" "github.com/miekg/coredns/middleware/test" "github.com/miekg/dns" @@ -57,14 +57,14 @@ func TestLookupWildcard(t *testing.T) { for _, tc := range wildcardTestCases { m := tc.Msg() - rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) + rec := dnsrecorder.New(&test.ResponseWriter{}) _, err := fm.ServeDNS(ctx, rec, m) if err != nil { t.Errorf("expected no error, got %v\n", err) return } - resp := rec.Msg() + resp := rec.Msg sort.Sort(test.RRSet(resp.Answer)) sort.Sort(test.RRSet(resp.Ns)) sort.Sort(test.RRSet(resp.Extra)) diff --git a/middleware/file/xfr.go b/middleware/file/xfr.go index 1d87a244b..d5ea1f2d3 100644 --- a/middleware/file/xfr.go +++ b/middleware/file/xfr.go @@ -4,7 +4,7 @@ import ( "fmt" "log" - "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/request" "github.com/miekg/dns" "golang.org/x/net/context" @@ -18,7 +18,7 @@ type ( // Serve an AXFR (and fallback of IXFR) as well. func (x Xfr) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { - state := middleware.State{W: w, Req: r} + state := request.Request{W: w, Req: r} if !x.TransferAllowed(state) { return dns.RcodeServerFailure, nil } diff --git a/middleware/file/zone.go b/middleware/file/zone.go index b84162cbb..ff8795d77 100644 --- a/middleware/file/zone.go +++ b/middleware/file/zone.go @@ -8,8 +8,8 @@ import ( "strings" "sync" - "github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware/file/tree" + "github.com/miekg/coredns/request" "github.com/fsnotify/fsnotify" "github.com/miekg/dns" @@ -102,7 +102,7 @@ func (z *Zone) Insert(r dns.RR) error { func (z *Zone) Delete(r dns.RR) { z.Tree.Delete(r) } // TransferAllowed checks if incoming request for transferring the zone is allowed according to the ACLs. -func (z *Zone) TransferAllowed(state middleware.State) bool { +func (z *Zone) TransferAllowed(req request.Request) bool { for _, t := range z.TransferTo { if t == "*" { return true diff --git a/middleware/fs_test.go b/middleware/fs_test.go deleted file mode 100644 index 44133c4eb..000000000 --- a/middleware/fs_test.go +++ /dev/null @@ -1,19 +0,0 @@ -package middleware - -import ( - "os" - "strings" - "testing" -) - -func TestfsPath(t *testing.T) { - if actual := fsPath(); !strings.HasSuffix(actual, ".coredns") { - t.Errorf("Expected path to be a .coredns folder, got: %v", actual) - } - - os.Setenv("COREDNSPATH", "testpath") - defer os.Setenv("COREDNSPATH", "") - if actual, expected := fsPath(), "testpath"; actual != expected { - t.Errorf("Expected path to be %v, got: %v", expected, actual) - } -} diff --git a/middleware/host.go b/middleware/host.go deleted file mode 100644 index 65dbefbca..000000000 --- a/middleware/host.go +++ /dev/null @@ -1,36 +0,0 @@ -package middleware - -import ( - "net" - "strings" - - "github.com/miekg/dns" -) - -// Host represents a host from the Corefile, may contain port. -type ( - Host string - Addr string -) - -// Normalize will return the host portion of host, stripping -// of any port. The host will also be fully qualified and lowercased. -func (h Host) Normalize() string { - // separate host and port - host, _, err := net.SplitHostPort(string(h)) - if err != nil { - host, _, _ = net.SplitHostPort(string(h) + ":") - } - return strings.ToLower(dns.Fqdn(host)) -} - -// Normalize will return a normalized address, if not port is specified -// port 53 is added, otherwise the port will be left as is. -func (a Addr) Normalize() string { - // separate host and port - addr, port, err := net.SplitHostPort(string(a)) - if err != nil { - addr, port, _ = net.SplitHostPort(string(a) + ":53") - } - return net.JoinHostPort(addr, port) -} diff --git a/middleware/kubernetes/handler.go b/middleware/kubernetes/handler.go index 1986820d5..a89dedc0f 100644 --- a/middleware/kubernetes/handler.go +++ b/middleware/kubernetes/handler.go @@ -2,16 +2,17 @@ package kubernetes import ( "fmt" - "strings" "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/middleware/pkg/dnsutil" + "github.com/miekg/coredns/request" "github.com/miekg/dns" "golang.org/x/net/context" ) func (k Kubernetes) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { - state := middleware.State{W: w, Req: r} + state := request.Request{W: w, Req: r} if state.QClass() != dns.ClassINET { return dns.RcodeServerFailure, fmt.Errorf("can only deal with ClassINET") } @@ -21,8 +22,8 @@ func (k Kubernetes) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.M m.Authoritative, m.RecursionAvailable, m.Compress = true, true, true // TODO: find an alternative to this block - if strings.HasSuffix(state.Name(), arpaSuffix) { - ip, _ := extractIP(state.Name()) + ip := dnsutil.ExtractAddressFromReverse(state.Name()) + if ip != "" { records := k.getServiceRecordForIP(ip, state.Name()) if len(records) > 0 { srvPTR := &records[0] @@ -100,7 +101,7 @@ func (k Kubernetes) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.M } // NoData write a nodata response to the client. -func (k Kubernetes) Err(zone string, rcode int, state middleware.State) (int, error) { +func (k Kubernetes) Err(zone string, rcode int, state request.Request) (int, error) { m := new(dns.Msg) m.SetRcode(state.Req, rcode) m.Ns = []dns.RR{k.SOA(zone, state)} diff --git a/middleware/kubernetes/kubernetes.go b/middleware/kubernetes/kubernetes.go index 569e089e0..0bd1dc7a4 100644 --- a/middleware/kubernetes/kubernetes.go +++ b/middleware/kubernetes/kubernetes.go @@ -4,13 +4,13 @@ package kubernetes import ( "errors" "log" - "strings" "time" "github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware/kubernetes/msg" "github.com/miekg/coredns/middleware/kubernetes/nametemplate" "github.com/miekg/coredns/middleware/kubernetes/util" + "github.com/miekg/coredns/middleware/pkg/dnsutil" "github.com/miekg/coredns/middleware/proxy" "github.com/miekg/dns" @@ -100,8 +100,8 @@ func (k *Kubernetes) getZoneForName(name string) (string, []string) { func (k *Kubernetes) Records(name string, exact bool) ([]msg.Service, error) { // TODO: refector this. // Right now GetNamespaceFromSegmentArray do not supports PRE queries - if strings.HasSuffix(name, arpaSuffix) { - ip, _ := extractIP(name) + ip := dnsutil.ExtractAddressFromReverse(name) + if ip != "" { records := k.getServiceRecordForIP(ip, name) return records, nil } diff --git a/middleware/kubernetes/lookup.go b/middleware/kubernetes/lookup.go index 0096e1fdb..e14d2275e 100644 --- a/middleware/kubernetes/lookup.go +++ b/middleware/kubernetes/lookup.go @@ -4,21 +4,17 @@ import ( "fmt" "math" "net" - "strings" "time" "github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware/kubernetes/msg" + "github.com/miekg/coredns/middleware/pkg/dnsutil" + "github.com/miekg/coredns/request" "github.com/miekg/dns" ) -const ( - // arpaSuffix is the standard suffix for PTR IP reverse lookups. - arpaSuffix = ".in-addr.arpa." -) - -func (k Kubernetes) records(state middleware.State, exact bool) ([]msg.Service, error) { +func (k Kubernetes) records(state request.Request, exact bool) ([]msg.Service, error) { services, err := k.Records(state.Name(), exact) if err != nil { return nil, err @@ -28,7 +24,7 @@ func (k Kubernetes) records(state middleware.State, exact bool) ([]msg.Service, return services, nil } -func (k Kubernetes) A(zone string, state middleware.State, previousRecords []dns.RR) (records []dns.RR, err error) { +func (k Kubernetes) A(zone string, state request.Request, previousRecords []dns.RR) (records []dns.RR, err error) { services, err := k.records(state, false) if err != nil { return nil, err @@ -49,11 +45,11 @@ func (k Kubernetes) A(zone string, state middleware.State, previousRecords []dns // don't add it, and just continue continue } - if isDuplicateCNAME(newRecord, previousRecords) { + if dnsutil.DuplicateCNAME(newRecord, previousRecords) { continue } - state1 := copyState(state, serv.Host, state.QType()) + state1 := state.NewWithQuestion(serv.Host, state.QType()) nextRecords, err := k.A(zone, state1, append(previousRecords, newRecord)) if err == nil { @@ -87,7 +83,7 @@ func (k Kubernetes) A(zone string, state middleware.State, previousRecords []dns return records, nil } -func (k Kubernetes) AAAA(zone string, state middleware.State, previousRecords []dns.RR) (records []dns.RR, err error) { +func (k Kubernetes) AAAA(zone string, state request.Request, previousRecords []dns.RR) (records []dns.RR, err error) { services, err := k.records(state, false) if err != nil { return nil, err @@ -108,11 +104,11 @@ func (k Kubernetes) AAAA(zone string, state middleware.State, previousRecords [] // don't add it, and just continue continue } - if isDuplicateCNAME(newRecord, previousRecords) { + if dnsutil.DuplicateCNAME(newRecord, previousRecords) { continue } - state1 := copyState(state, serv.Host, state.QType()) + state1 := state.NewWithQuestion(serv.Host, state.QType()) nextRecords, err := k.AAAA(zone, state1, append(previousRecords, newRecord)) if err == nil { @@ -149,7 +145,7 @@ func (k Kubernetes) AAAA(zone string, state middleware.State, previousRecords [] // SRV returns SRV records from kubernetes. // If the Target is not a name but an IP address, a name is created on the fly. -func (k Kubernetes) SRV(zone string, state middleware.State) (records []dns.RR, extra []dns.RR, err error) { +func (k Kubernetes) SRV(zone string, state request.Request) (records []dns.RR, extra []dns.RR, err error) { services, err := k.records(state, false) if err != nil { return nil, nil, err @@ -207,7 +203,7 @@ func (k Kubernetes) SRV(zone string, state middleware.State) (records []dns.RR, } // Internal name, we should have some info on them, either v4 or v6 // Clients expect a complete answer, because we are a recursor in their view. - state1 := copyState(state, srv.Target, dns.TypeA) + state1 := state.NewWithQuestion(srv.Target, dns.TypeA) addr, e1 := k.A(zone, state1, nil) if e1 == nil { extra = append(extra, addr...) @@ -231,21 +227,21 @@ func (k Kubernetes) SRV(zone string, state middleware.State) (records []dns.RR, } // Returning MX records from kubernetes not implemented. -func (k Kubernetes) MX(zone string, state middleware.State) (records []dns.RR, extra []dns.RR, err error) { +func (k Kubernetes) MX(zone string, state request.Request) (records []dns.RR, extra []dns.RR, err error) { return nil, nil, err } // Returning CNAME records from kubernetes not implemented. -func (k Kubernetes) CNAME(zone string, state middleware.State) (records []dns.RR, err error) { +func (k Kubernetes) CNAME(zone string, state request.Request) (records []dns.RR, err error) { return nil, err } // Returning TXT records from kubernetes not implemented. -func (k Kubernetes) TXT(zone string, state middleware.State) (records []dns.RR, err error) { +func (k Kubernetes) TXT(zone string, state request.Request) (records []dns.RR, err error) { return nil, err } -func (k Kubernetes) NS(zone string, state middleware.State) (records, extra []dns.RR, err error) { +func (k Kubernetes) NS(zone string, state request.Request) (records, extra []dns.RR, err error) { // NS record for this zone live in a special place, ns.dns.. Fake our lookup. // only a tad bit fishy... old := state.QName() @@ -278,7 +274,7 @@ func (k Kubernetes) NS(zone string, state middleware.State) (records, extra []dn } // SOA Record returns a SOA record. -func (k Kubernetes) SOA(zone string, state middleware.State) *dns.SOA { +func (k Kubernetes) SOA(zone string, state request.Request) *dns.SOA { header := dns.RR_Header{Name: zone, Rrtype: dns.TypeSOA, Ttl: 300, Class: dns.ClassINET} return &dns.SOA{Hdr: header, Mbox: "hostmaster." + zone, @@ -291,9 +287,9 @@ func (k Kubernetes) SOA(zone string, state middleware.State) *dns.SOA { } } -func (k Kubernetes) PTR(zone string, state middleware.State) ([]dns.RR, error) { - reverseIP, ok := extractIP(state.Name()) - if !ok { +func (k Kubernetes) PTR(zone string, state request.Request) ([]dns.RR, error) { + reverseIP := dnsutil.ExtractAddressFromReverse(state.Name()) + if reverseIP == "" { return nil, fmt.Errorf("does not support reverse lookup for %s", state.QName()) } @@ -318,41 +314,3 @@ func (k Kubernetes) PTR(zone string, state middleware.State) ([]dns.RR, error) { } return records, nil } - -func isDuplicateCNAME(r *dns.CNAME, records []dns.RR) bool { - for _, rec := range records { - if v, ok := rec.(*dns.CNAME); ok { - if v.Target == r.Target { - return true - } - } - } - return false -} - -func copyState(state middleware.State, target string, typ uint16) middleware.State { - state1 := middleware.State{W: state.W, Req: state.Req.Copy()} - state1.Req.Question[0] = dns.Question{Name: dns.Fqdn(target), Qtype: dns.ClassINET, Qclass: typ} - return state1 -} - -// extractIP turns a standard PTR reverse record lookup name -// into an IP address -func extractIP(reverseName string) (string, bool) { - if !strings.HasSuffix(reverseName, arpaSuffix) { - return "", false - } - search := strings.TrimSuffix(reverseName, arpaSuffix) - - // reverse the segments and then combine them - segments := reverseArray(strings.Split(search, ".")) - return strings.Join(segments, "."), true -} - -func reverseArray(arr []string) []string { - for i := 0; i < len(arr)/2; i++ { - j := len(arr) - i - 1 - arr[i], arr[j] = arr[j], arr[i] - } - return arr -} diff --git a/middleware/kubernetes/setup.go b/middleware/kubernetes/setup.go index fc3c036b8..09e4478e3 100644 --- a/middleware/kubernetes/setup.go +++ b/middleware/kubernetes/setup.go @@ -68,7 +68,7 @@ func kubernetesParse(c *caddy.Controller) (Kubernetes, error) { } k8s.Zones = NormalizeZoneList(zones) - middleware.Zones(k8s.Zones).FullyQualify() + middleware.Zones(k8s.Zones).Normalize() if k8s.Zones == nil || len(k8s.Zones) < 1 { err = errors.New("Zone name must be provided for kubernetes middleware.") diff --git a/middleware/loadbalance/loadbalance_test.go b/middleware/loadbalance/loadbalance_test.go index 5f1ff2ecb..5e240be13 100644 --- a/middleware/loadbalance/loadbalance_test.go +++ b/middleware/loadbalance/loadbalance_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/middleware/pkg/dnsrecorder" "github.com/miekg/coredns/middleware/test" "github.com/miekg/dns" @@ -56,7 +57,7 @@ func TestLoadBalance(t *testing.T) { }, } - rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) + rec := dnsrecorder.New(&test.ResponseWriter{}) for i, test := range tests { req := new(dns.Msg) @@ -71,7 +72,7 @@ func TestLoadBalance(t *testing.T) { } cname := 0 - for _, r := range rec.Msg().Answer { + for _, r := range rec.Msg.Answer { if r.Header().Rrtype != dns.TypeCNAME { break } @@ -81,7 +82,7 @@ func TestLoadBalance(t *testing.T) { t.Errorf("Test %d: Expected %d cnames in Answer, but got %d", i, test.cnameAnswer, cname) } cname = 0 - for _, r := range rec.Msg().Extra { + for _, r := range rec.Msg.Extra { if r.Header().Rrtype != dns.TypeCNAME { break } diff --git a/middleware/log/log.go b/middleware/log/log.go index 32d40632a..9da5e65aa 100644 --- a/middleware/log/log.go +++ b/middleware/log/log.go @@ -7,6 +7,11 @@ import ( "github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware/metrics" + "github.com/miekg/coredns/middleware/pkg/dnsrecorder" + "github.com/miekg/coredns/middleware/pkg/rcode" + "github.com/miekg/coredns/middleware/pkg/replacer" + "github.com/miekg/coredns/middleware/pkg/roller" + "github.com/miekg/coredns/request" "github.com/miekg/dns" "golang.org/x/net/context" @@ -20,32 +25,30 @@ type Logger struct { } func (l Logger) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { - state := middleware.State{W: w, Req: r} + state := request.Request{W: w, Req: r} for _, rule := range l.Rules { if middleware.Name(rule.NameScope).Matches(state.Name()) { - responseRecorder := middleware.NewResponseRecorder(w) - rcode, err := l.Next.ServeDNS(ctx, responseRecorder, r) + responseRecorder := dnsrecorder.New(w) + rc, err := l.Next.ServeDNS(ctx, responseRecorder, r) - if rcode > 0 { + if rc > 0 { // There was an error up the chain, but no response has been written yet. // The error must be handled here so the log entry will record the response size. if l.ErrorFunc != nil { - l.ErrorFunc(responseRecorder, r, rcode) + l.ErrorFunc(responseRecorder, r, rc) } else { - rc := middleware.RcodeToString(rcode) - answer := new(dns.Msg) - answer.SetRcode(r, rcode) + answer.SetRcode(r, rc) state.SizeAndDo(answer) - metrics.Report(state, metrics.Dropped, rc, answer.Len(), time.Now()) + metrics.Report(state, metrics.Dropped, rcode.ToString(rc), answer.Len(), time.Now()) w.WriteMsg(answer) } - rcode = 0 + rc = 0 } - rep := middleware.NewReplacer(r, responseRecorder, CommonLogEmptyValue) + rep := replacer.New(r, responseRecorder, CommonLogEmptyValue) rule.Log.Println(rep.Replace(rule.Format)) - return rcode, err + return rc, err } } @@ -58,7 +61,7 @@ type Rule struct { OutputFile string Format string Log *log.Logger - Roller *middleware.LogRoller + Roller *roller.LogRoller } const ( diff --git a/middleware/log/log_test.go b/middleware/log/log_test.go index b5df1ad76..77e3f2e3c 100644 --- a/middleware/log/log_test.go +++ b/middleware/log/log_test.go @@ -6,7 +6,7 @@ import ( "strings" "testing" - "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/middleware/pkg/dnsrecorder" "github.com/miekg/coredns/middleware/test" "github.com/miekg/dns" @@ -37,7 +37,7 @@ func TestLoggedStatus(t *testing.T) { r := new(dns.Msg) r.SetQuestion("example.org.", dns.TypeA) - rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) + rec := dnsrecorder.New(&test.ResponseWriter{}) rcode, _ := logger.ServeDNS(ctx, rec, r) if rcode != 0 { diff --git a/middleware/log/setup.go b/middleware/log/setup.go index a1e143e6b..0721090a9 100644 --- a/middleware/log/setup.go +++ b/middleware/log/setup.go @@ -6,7 +6,7 @@ import ( "os" "github.com/miekg/coredns/core/dnsserver" - "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/middleware/pkg/roller" "github.com/hashicorp/go-syslog" "github.com/mholt/caddy" @@ -75,13 +75,13 @@ func logParse(c *caddy.Controller) ([]Rule, error) { for c.Next() { args := c.RemainingArgs() - var logRoller *middleware.LogRoller + var logRoller *roller.LogRoller if c.NextBlock() { if c.Val() == "rotate" { if c.NextArg() { if c.Val() == "{" { var err error - logRoller, err = middleware.ParseRoller(c) + logRoller, err = roller.Parse(c) if err != nil { return nil, err } diff --git a/middleware/log/setup_test.go b/middleware/log/setup_test.go index 0a3ee63fe..b38caa52d 100644 --- a/middleware/log/setup_test.go +++ b/middleware/log/setup_test.go @@ -3,7 +3,7 @@ package log import ( "testing" - "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/middleware/pkg/roller" "github.com/mholt/caddy" ) @@ -68,7 +68,7 @@ func TestLogParse(t *testing.T) { NameScope: ".", OutputFile: "access.log", Format: DefaultLogFormat, - Roller: &middleware.LogRoller{ + Roller: &roller.LogRoller{ MaxSize: 2, MaxAge: 10, MaxBackups: 3, diff --git a/middleware/metrics/handler.go b/middleware/metrics/handler.go index 1bdc34897..6cac7e81e 100644 --- a/middleware/metrics/handler.go +++ b/middleware/metrics/handler.go @@ -4,13 +4,16 @@ import ( "time" "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/middleware/pkg/dnsrecorder" + "github.com/miekg/coredns/middleware/pkg/rcode" + "github.com/miekg/coredns/request" "github.com/miekg/dns" "golang.org/x/net/context" ) func (m Metrics) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { - state := middleware.State{W: w, Req: r} + state := request.Request{W: w, Req: r} qname := state.QName() zone := middleware.Zones(m.ZoneNames).Matches(qname) @@ -19,35 +22,35 @@ func (m Metrics) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) } // Record response to get status code and size of the reply. - rw := middleware.NewResponseRecorder(w) + rw := dnsrecorder.New(w) status, err := m.Next.ServeDNS(ctx, rw, r) - Report(state, zone, rw.Rcode(), rw.Size(), rw.Start()) + Report(state, zone, rcode.ToString(rw.Rcode), rw.Size, rw.Start) return status, err } // Report is a plain reporting function that the server can use for REFUSED and other // queries that are turned down because they don't match any middleware. -func Report(state middleware.State, zone, rcode string, size int, start time.Time) { +func Report(req request.Request, zone, rcode string, size int, start time.Time) { if requestCount == nil { // no metrics are enabled return } // Proto and Family - net := state.Proto() + net := req.Proto() fam := "1" - if state.Family() == 2 { + if req.Family() == 2 { fam = "2" } - typ := state.QType() + typ := req.QType() requestCount.WithLabelValues(zone, net, fam).Inc() requestDuration.WithLabelValues(zone).Observe(float64(time.Since(start) / time.Millisecond)) - if state.Do() { + if req.Do() { requestDo.WithLabelValues(zone).Inc() } @@ -59,10 +62,10 @@ func Report(state middleware.State, zone, rcode string, size int, start time.Tim if typ == dns.TypeIXFR || typ == dns.TypeAXFR { responseTransferSize.WithLabelValues(zone, net).Observe(float64(size)) - requestTransferSize.WithLabelValues(zone, net).Observe(float64(state.Size())) + requestTransferSize.WithLabelValues(zone, net).Observe(float64(req.Size())) } else { responseSize.WithLabelValues(zone, net).Observe(float64(size)) - requestSize.WithLabelValues(zone, net).Observe(float64(state.Size())) + requestSize.WithLabelValues(zone, net).Observe(float64(req.Size())) } responseRcode.WithLabelValues(zone, rcode).Inc() diff --git a/middleware/name.go b/middleware/name.go deleted file mode 100644 index 616ae68bd..000000000 --- a/middleware/name.go +++ /dev/null @@ -1,25 +0,0 @@ -package middleware - -import ( - "strings" - - "github.com/miekg/dns" -) - -// Name represents a domain name. -type Name string - -// Matches checks to see if other is a subdomain (or the same domain) of n. -// This method assures that names can be easily and consistently matched. -func (n Name) Matches(child string) bool { - if dns.Name(n) == dns.Name(child) { - return true - } - - return dns.IsSubDomain(string(n), child) -} - -// Normalize lowercases and makes n fully qualified. -func (n Name) Normalize() string { - return strings.ToLower(dns.Fqdn(string(n))) -} diff --git a/middleware/normalize.go b/middleware/normalize.go new file mode 100644 index 000000000..e5b747620 --- /dev/null +++ b/middleware/normalize.go @@ -0,0 +1,78 @@ +package middleware + +import ( + "net" + "strings" + + "github.com/miekg/dns" +) + +type Zones []string + +// Matches checks is qname is a subdomain of any of the zones in z. The match +// will return the most specific zones that matches other. The empty string +// signals a not found condition. +func (z Zones) Matches(qname string) string { + zone := "" + for _, zname := range z { + if dns.IsSubDomain(zname, qname) { + // TODO(miek): hmm, add test for this case + if len(zname) > len(zone) { + zone = zname + } + } + } + return zone +} + +// Normalize fully qualifies all zones in z. +func (z Zones) Normalize() { + for i, _ := range z { + z[i] = Name(z[i]).Normalize() + } +} + +// Name represents a domain name. +type Name string + +// Matches checks to see if other is a subdomain (or the same domain) of n. +// This method assures that names can be easily and consistently matched. +func (n Name) Matches(child string) bool { + if dns.Name(n) == dns.Name(child) { + return true + } + + return dns.IsSubDomain(string(n), child) +} + +// Normalize lowercases and makes n fully qualified. +func (n Name) Normalize() string { return strings.ToLower(dns.Fqdn(string(n))) } + +// Host represents a host from the Corefile, may contain port. +type ( + Host string + Addr string +) + +// Normalize will return the host portion of host, stripping +// of any port. The host will also be fully qualified and lowercased. +func (h Host) Normalize() string { + // separate host and port + host, _, err := net.SplitHostPort(string(h)) + if err != nil { + host, _, _ = net.SplitHostPort(string(h) + ":") + } + return Name(host).Normalize() +} + +// Normalize will return a normalized address, if not port is specified +// port 53 is added, otherwise the port will be left as is. +func (a Addr) Normalize() string { + // separate host and port + addr, port, err := net.SplitHostPort(string(a)) + if err != nil { + addr, port, _ = net.SplitHostPort(string(a) + ":53") + } + // TODO(miek): lowercase it? + return net.JoinHostPort(addr, port) +} diff --git a/middleware/pkg/dnsrecorder/recorder.go b/middleware/pkg/dnsrecorder/recorder.go new file mode 100644 index 000000000..9bf045e91 --- /dev/null +++ b/middleware/pkg/dnsrecorder/recorder.go @@ -0,0 +1,57 @@ +package dnsrecorder + +import ( + "time" + + "github.com/miekg/dns" +) + +// Recorder is a type of ResponseWriter that captures +// the rcode code written to it and also the size of the message +// written in the response. A rcode code does not have +// to be written, however, in which case 0 must be assumed. +// It is best to have the constructor initialize this type +// with that default status code. +type Recorder struct { + dns.ResponseWriter + Rcode int + Size int + Msg *dns.Msg + Start time.Time +} + +// New makes and returns a new Recorder, +// which captures the DNS rcode from the ResponseWriter +// and also the length of the response message written through it. +func New(w dns.ResponseWriter) *Recorder { + return &Recorder{ + ResponseWriter: w, + Rcode: 0, + Msg: nil, + Start: time.Now(), + } +} + +// WriteMsg records the status code and calls the +// underlying ResponseWriter's WriteMsg method. +func (r *Recorder) WriteMsg(res *dns.Msg) error { + r.Rcode = res.Rcode + // We may get called multiple times (axfr for instance). + // Save the last message, but add the sizes. + r.Size += res.Len() + r.Msg = res + return r.ResponseWriter.WriteMsg(res) +} + +// Write is a wrapper that records the size of the message that gets written. +func (r *Recorder) Write(buf []byte) (int, error) { + n, err := r.ResponseWriter.Write(buf) + if err == nil { + r.Size += n + } + return n, err +} + +// Hijack implements dns.Hijacker. It simply wraps the underlying +// ResponseWriter's Hijack method if there is one, or returns an error. +func (r *Recorder) Hijack() { r.ResponseWriter.Hijack(); return } diff --git a/middleware/recorder_test.go b/middleware/pkg/dnsrecorder/recorder_test.go similarity index 97% rename from middleware/recorder_test.go rename to middleware/pkg/dnsrecorder/recorder_test.go index 30931a0e3..c9c2f6ce4 100644 --- a/middleware/recorder_test.go +++ b/middleware/pkg/dnsrecorder/recorder_test.go @@ -1,4 +1,4 @@ -package middleware +package dnsrecorder /* func TestNewResponseRecorder(t *testing.T) { diff --git a/middleware/pkg/dnsutil/cname.go b/middleware/pkg/dnsutil/cname.go new file mode 100644 index 000000000..281e03218 --- /dev/null +++ b/middleware/pkg/dnsutil/cname.go @@ -0,0 +1,15 @@ +package dnsutil + +import "github.com/miekg/dns" + +// DuplicateCNAME returns true if r already exists in records. +func DuplicateCNAME(r *dns.CNAME, records []dns.RR) bool { + for _, rec := range records { + if v, ok := rec.(*dns.CNAME); ok { + if v.Target == r.Target { + return true + } + } + } + return false +} diff --git a/middleware/pkg/dnsutil/reverse.go b/middleware/pkg/dnsutil/reverse.go new file mode 100644 index 000000000..a360432f3 --- /dev/null +++ b/middleware/pkg/dnsutil/reverse.go @@ -0,0 +1,40 @@ +package dnsutil + +import "strings" + +// ExtractAddressFromReverse turns a standard PTR reverse record name +// into an IP address. This works for ipv4 or ipv6. +// +// 54.119.58.176.in-addr.arpa. becomes 176.58.119.54. If the conversion +// failes the empty string is returned. +func ExtractAddressFromReverse(reverseName string) string { + search := "" + + switch { + case strings.HasSuffix(reverseName, v4arpaSuffix): + search = strings.TrimSuffix(reverseName, v4arpaSuffix) + case strings.HasSuffix(reverseName, v6arpaSuffix): + search = strings.TrimSuffix(reverseName, v6arpaSuffix) + default: + return "" + } + + // Reverse the segments and then combine them. + segments := reverse(strings.Split(search, ".")) + return strings.Join(segments, ".") +} + +func reverse(slice []string) []string { + for i := 0; i < len(slice)/2; i++ { + j := len(slice) - i - 1 + slice[i], slice[j] = slice[j], slice[i] + } + return slice +} + +const ( + // v4arpaSuffix is the reverse tree suffix for v4 IP addresses. + v4arpaSuffix = ".in-addr.arpa." + // v6arpaSuffix is the reverse tree suffix for v6 IP addresses. + v6arpaSuffix = ".ip6.arpa." +) diff --git a/middleware/edns.go b/middleware/pkg/edns/edns.go similarity index 76% rename from middleware/edns.go rename to middleware/pkg/edns/edns.go index c60bbc44d..6704066b0 100644 --- a/middleware/edns.go +++ b/middleware/pkg/edns/edns.go @@ -1,4 +1,4 @@ -package middleware +package edns import ( "errors" @@ -6,11 +6,11 @@ import ( "github.com/miekg/dns" ) -// Edns0Version checks the EDNS version in the request. If error +// Version checks the EDNS version in the request. If error // is nil everything is OK and we can invoke the middleware. If non-nil, the // returned Msg is valid to be returned to the client (and should). For some // reason this response should not contain a question RR in the question section. -func Edns0Version(req *dns.Msg) (*dns.Msg, error) { +func Version(req *dns.Msg) (*dns.Msg, error) { opt := req.IsEdns0() if opt == nil { return nil, nil @@ -33,8 +33,8 @@ func Edns0Version(req *dns.Msg) (*dns.Msg, error) { return m, errors.New("EDNS0 BADVERS") } -// edns0Size returns a normalized size based on proto. -func edns0Size(proto string, size int) int { +// Size returns a normalized size based on proto. +func Size(proto string, size int) int { if proto == "tcp" { return dns.MaxMsgSize } diff --git a/middleware/edns_test.go b/middleware/pkg/edns/edns_test.go similarity index 75% rename from middleware/edns_test.go rename to middleware/pkg/edns/edns_test.go index 7b4e6fc66..89ac6d2ec 100644 --- a/middleware/edns_test.go +++ b/middleware/pkg/edns/edns_test.go @@ -1,4 +1,4 @@ -package middleware +package edns import ( "testing" @@ -6,21 +6,21 @@ import ( "github.com/miekg/dns" ) -func TestEdns0Version(t *testing.T) { +func TestVersion(t *testing.T) { m := ednsMsg() m.Extra[0].(*dns.OPT).SetVersion(2) - _, err := Edns0Version(m) + _, err := Version(m) if err == nil { t.Errorf("expected wrong version, but got OK") } } -func TestEdns0VersionNoEdns(t *testing.T) { +func TestVersionNoEdns(t *testing.T) { m := ednsMsg() m.Extra = nil - _, err := Edns0Version(m) + _, err := Version(m) if err != nil { t.Errorf("expected no error, but got one: %s", err) } diff --git a/middleware/rcode.go b/middleware/pkg/rcode/rcode.go similarity index 72% rename from middleware/rcode.go rename to middleware/pkg/rcode/rcode.go index 989f90fdd..006440071 100644 --- a/middleware/rcode.go +++ b/middleware/pkg/rcode/rcode.go @@ -1,4 +1,4 @@ -package middleware +package rcode import ( "strconv" @@ -6,7 +6,7 @@ import ( "github.com/miekg/dns" ) -func RcodeToString(rcode int) string { +func ToString(rcode int) string { if str, ok := dns.RcodeToString[rcode]; ok { return str } diff --git a/middleware/replacer.go b/middleware/pkg/replacer/replacer.go similarity index 76% rename from middleware/replacer.go rename to middleware/pkg/replacer/replacer.go index a2fac3113..e90f29368 100644 --- a/middleware/replacer.go +++ b/middleware/pkg/replacer/replacer.go @@ -1,10 +1,13 @@ -package middleware +package replacer import ( "strconv" "strings" "time" + "github.com/miekg/coredns/middleware/pkg/dnsrecorder" + "github.com/miekg/coredns/request" + "github.com/miekg/dns" ) @@ -22,46 +25,43 @@ type replacer struct { emptyValue string } -// NewReplacer makes a new replacer based on r and rr. +// New makes a new replacer based on r and rr. // Do not create a new replacer until r and rr have all // the needed values, because this function copies those // values into the replacer. rr may be nil if it is not // available. emptyValue should be the string that is used // in place of empty string (can still be empty string). -func NewReplacer(r *dns.Msg, rr *ResponseRecorder, emptyValue string) Replacer { - state := State{W: rr, Req: r} +func New(r *dns.Msg, rr *dnsrecorder.Recorder, emptyValue string) Replacer { + req := request.Request{W: rr, Req: r} rep := replacer{ replacements: map[string]string{ - "{type}": state.Type(), - "{name}": state.Name(), - "{class}": state.Class(), - "{proto}": state.Proto(), + "{type}": req.Type(), + "{name}": req.Name(), + "{class}": req.Class(), + "{proto}": req.Proto(), "{when}": func() string { return time.Now().Format(timeFormat) }(), - "{remote}": state.IP(), - "{port}": func() string { - p, _ := state.Port() - return p - }(), + "{remote}": req.IP(), + "{port}": req.Port(), }, emptyValue: emptyValue, } if rr != nil { - rcode := dns.RcodeToString[rr.rcode] + rcode := dns.RcodeToString[rr.Rcode] if rcode == "" { - rcode = strconv.Itoa(rr.rcode) + rcode = strconv.Itoa(rr.Rcode) } rep.replacements["{rcode}"] = rcode - rep.replacements["{size}"] = strconv.Itoa(rr.size) - rep.replacements["{duration}"] = time.Since(rr.start).String() + rep.replacements["{size}"] = strconv.Itoa(rr.Size) + rep.replacements["{duration}"] = time.Since(rr.Start).String() } // Header placeholders (case-insensitive) rep.replacements[headerReplacer+"id}"] = strconv.Itoa(int(r.Id)) rep.replacements[headerReplacer+"opcode}"] = strconv.Itoa(int(r.Opcode)) - rep.replacements[headerReplacer+"do}"] = boolToString(state.Do()) - rep.replacements[headerReplacer+"bufsize}"] = strconv.Itoa(state.Size()) + rep.replacements[headerReplacer+"do}"] = boolToString(req.Do()) + rep.replacements[headerReplacer+"bufsize}"] = strconv.Itoa(req.Size()) return rep } diff --git a/middleware/replacer_test.go b/middleware/pkg/replacer/replacer_test.go similarity index 99% rename from middleware/replacer_test.go rename to middleware/pkg/replacer/replacer_test.go index 378e4083d..09904c1a9 100644 --- a/middleware/replacer_test.go +++ b/middleware/pkg/replacer/replacer_test.go @@ -1,4 +1,4 @@ -package middleware +package replacer /* func TestNewReplacer(t *testing.T) { diff --git a/middleware/classify.go b/middleware/pkg/response/classify.go similarity index 62% rename from middleware/classify.go rename to middleware/pkg/response/classify.go index 72c131157..adbaa6526 100644 --- a/middleware/classify.go +++ b/middleware/pkg/response/classify.go @@ -1,19 +1,19 @@ -package middleware +package response import "github.com/miekg/dns" -type MsgType int +type Type int const ( - Success MsgType = iota - NameError // NXDOMAIN in header, SOA in auth. - NoData // NOERROR in header, SOA in auth. - Delegation // NOERROR in header, NS in auth, optionally fluff in additional (not checked). - OtherError // Don't cache these. + Success Type = iota + NameError // NXDOMAIN in header, SOA in auth. + NoData // NOERROR in header, SOA in auth. + Delegation // NOERROR in header, NS in auth, optionally fluff in additional (not checked). + OtherError // Don't cache these. ) -// Classify classifies a message, it returns the MessageType. -func Classify(m *dns.Msg) (MsgType, *dns.OPT) { +// Classify classifies a message, it returns the Type. +func Classify(m *dns.Msg) (Type, *dns.OPT) { opt := m.IsEdns0() if len(m.Answer) > 0 && m.Rcode == dns.RcodeSuccess { diff --git a/middleware/classify_test.go b/middleware/pkg/response/classify_test.go similarity index 97% rename from middleware/classify_test.go rename to middleware/pkg/response/classify_test.go index 26c52db55..33ac815e3 100644 --- a/middleware/classify_test.go +++ b/middleware/pkg/response/classify_test.go @@ -1,4 +1,4 @@ -package middleware +package response import ( "testing" diff --git a/middleware/roller.go b/middleware/pkg/roller/roller.go similarity index 94% rename from middleware/roller.go rename to middleware/pkg/roller/roller.go index 81ff71c44..c60494736 100644 --- a/middleware/roller.go +++ b/middleware/pkg/roller/roller.go @@ -1,4 +1,4 @@ -package middleware +package roller import ( "io" @@ -8,7 +8,7 @@ import ( "gopkg.in/natefinch/lumberjack.v2" ) -func ParseRoller(c *caddy.Controller) (*LogRoller, error) { +func Parse(c *caddy.Controller) (*LogRoller, error) { var size, age, keep int // This is kind of a hack to support nested blocks: // As we are already in a block: either log or errors, diff --git a/singleflight/singleflight.go b/middleware/pkg/singleflight/singleflight.go similarity index 100% rename from singleflight/singleflight.go rename to middleware/pkg/singleflight/singleflight.go diff --git a/singleflight/singleflight_test.go b/middleware/pkg/singleflight/singleflight_test.go similarity index 100% rename from singleflight/singleflight_test.go rename to middleware/pkg/singleflight/singleflight_test.go diff --git a/middleware/fs.go b/middleware/pkg/storage/fs.go similarity index 64% rename from middleware/fs.go rename to middleware/pkg/storage/fs.go index 5970d0e99..3ee14b7ed 100644 --- a/middleware/fs.go +++ b/middleware/pkg/storage/fs.go @@ -1,25 +1,39 @@ -package middleware +package storage import ( "net/http" "os" + "path" "path/filepath" "runtime" ) -// dir wraps http.Dir that restrict file access to a specific directory tree. +// dir wraps an http.Dir that restrict file access to a specific directory tree, see http.Dir's documentation +// for methods for accessing files. type dir http.Dir // CoreDir is the directory where middleware can store assets, like zone files after a zone transfer // or public and private keys or anything else a middleware might need. The convention is to place -// assets in a subdirectory named after the fully qualified zone. +// assets in a subdirectory named after the zone prefixed with "D", to prevent the root zone become a hidden directory. // -// example.org./Kexample.key +// Dexample.org/Kexample.org.key +// +// Note that subzone(s) under example.org are places in the own directory under CoreDir: +// +// Dexample.org/... +// Db.example.org/... // // CoreDir will default to "$HOME/.coredns" on Unix, but it's location can be overriden with the COREDNSPATH // environment variable. var CoreDir dir = dir(fsPath()) +func (d dir) Zone(z string) dir { + if z != "." && z[len(z)-2] == '.' { + return dir(path.Join(string(d), "D"+z[:len(z)-1])) + } + return dir(path.Join(string(d), "D"+z)) +} + // fsPath returns the path to the directory where the application may store data. // If COREDNSPATH env variable. is set, that value is used. Otherwise, the path is // the result of evaluating "$HOME/.coredns". diff --git a/middleware/pkg/storage/fs_test.go b/middleware/pkg/storage/fs_test.go new file mode 100644 index 000000000..f7e8ccf9d --- /dev/null +++ b/middleware/pkg/storage/fs_test.go @@ -0,0 +1,42 @@ +package storage + +import ( + "os" + "path" + "strings" + "testing" +) + +func TestfsPath(t *testing.T) { + if actual := fsPath(); !strings.HasSuffix(actual, ".coredns") { + t.Errorf("Expected path to be a .coredns folder, got: %v", actual) + } + + os.Setenv("COREDNSPATH", "testpath") + defer os.Setenv("COREDNSPATH", "") + if actual, expected := fsPath(), "testpath"; actual != expected { + t.Errorf("Expected path to be %v, got: %v", expected, actual) + } +} + +func TestZone(t *testing.T) { + for _, ts := range []string{"example.org.", "example.org"} { + d := CoreDir.Zone(ts) + actual := path.Base(string(d)) + expected := "D" + ts + if actual != expected { + t.Errorf("Expected path to be %v, got %v", actual, expected) + } + } +} + +func TestZoneRoot(t *testing.T) { + for _, ts := range []string{"."} { + d := CoreDir.Zone(ts) + actual := path.Base(string(d)) + expected := "D" + ts + if actual != expected { + t.Errorf("Expected path to be %v, got %v", actual, expected) + } + } +} diff --git a/middleware/proxy/lookup.go b/middleware/proxy/lookup.go index 22a0c77d6..e2a3a0c77 100644 --- a/middleware/proxy/lookup.go +++ b/middleware/proxy/lookup.go @@ -7,7 +7,8 @@ import ( "sync/atomic" "time" - "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/request" + "github.com/miekg/dns" ) @@ -54,18 +55,19 @@ func New(hosts []string) Proxy { // Lookup will use name and type to forge a new message and will send that upstream. It will // set any EDNS0 options correctly so that downstream will be able to process the reply. // Lookup is not suitable for forwarding request. Ssee for that. -func (p Proxy) Lookup(state middleware.State, name string, tpe uint16) (*dns.Msg, error) { +func (p Proxy) Lookup(state request.Request, name string, tpe uint16) (*dns.Msg, error) { req := new(dns.Msg) req.SetQuestion(name, tpe) state.SizeAndDo(req) + return p.lookup(state, req) } -func (p Proxy) Forward(state middleware.State) (*dns.Msg, error) { +func (p Proxy) Forward(state request.Request) (*dns.Msg, error) { return p.lookup(state, state.Req) } -func (p Proxy) lookup(state middleware.State, r *dns.Msg) (*dns.Msg, error) { +func (p Proxy) lookup(state request.Request, r *dns.Msg) (*dns.Msg, error) { var ( reply *dns.Msg err error @@ -84,9 +86,9 @@ func (p Proxy) lookup(state middleware.State, r *dns.Msg) (*dns.Msg, error) { atomic.AddInt64(&host.Conns, 1) if state.Proto() == "tcp" { - reply, err = middleware.Exchange(p.Client.TCP, r, host.Name) + reply, _, err = p.Client.TCP.Exchange(r, host.Name) } else { - reply, err = middleware.Exchange(p.Client.UDP, r, host.Name) + reply, _, err = p.Client.UDP.Exchange(r, host.Name) } atomic.AddInt64(&host.Conns, -1) diff --git a/middleware/proxy/reverseproxy.go b/middleware/proxy/reverseproxy.go index 071452ecb..c452e296c 100644 --- a/middleware/proxy/reverseproxy.go +++ b/middleware/proxy/reverseproxy.go @@ -2,7 +2,7 @@ package proxy import ( - "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/request" "github.com/miekg/dns" ) @@ -20,10 +20,10 @@ func (p ReverseProxy) ServeDNS(w dns.ResponseWriter, r *dns.Msg, extra []dns.RR) ) switch { - case middleware.Proto(w) == "tcp": - reply, err = middleware.Exchange(p.Client.TCP, r, p.Host) + case request.Proto(w) == "tcp": // TODO(miek): keep this in request + reply, _, err = p.Client.TCP.Exchange(r, p.Host) default: - reply, err = middleware.Exchange(p.Client.UDP, r, p.Host) + reply, _, err = p.Client.UDP.Exchange(r, p.Host) } if reply != nil && reply.Truncated { diff --git a/middleware/recorder.go b/middleware/recorder.go deleted file mode 100644 index d1e466ec3..000000000 --- a/middleware/recorder.go +++ /dev/null @@ -1,72 +0,0 @@ -package middleware - -import ( - "time" - - "github.com/miekg/dns" -) - -// ResponseRecorder is a type of ResponseWriter that captures -// the rcode code written to it and also the size of the message -// written in the response. A rcode code does not have -// to be written, however, in which case 0 must be assumed. -// It is best to have the constructor initialize this type -// with that default status code. -type ResponseRecorder struct { - dns.ResponseWriter - rcode int - size int - msg *dns.Msg - start time.Time -} - -// NewResponseRecorder makes and returns a new responseRecorder, -// which captures the DNS rcode from the ResponseWriter -// and also the length of the response message written through it. -func NewResponseRecorder(w dns.ResponseWriter) *ResponseRecorder { - return &ResponseRecorder{ - ResponseWriter: w, - rcode: 0, - msg: nil, - start: time.Now(), - } -} - -// WriteMsg records the status code and calls the -// underlying ResponseWriter's WriteMsg method. -func (r *ResponseRecorder) WriteMsg(res *dns.Msg) error { - r.rcode = res.Rcode - // We may get called multiple times (axfr for instance). - // Save the last message, but add the sizes. - r.size += res.Len() - r.msg = res - return r.ResponseWriter.WriteMsg(res) -} - -// Write is a wrapper that records the size of the message that gets written. -func (r *ResponseRecorder) Write(buf []byte) (int, error) { - n, err := r.ResponseWriter.Write(buf) - if err == nil { - r.size += n - } - return n, err -} - -// Size returns the size. -func (r *ResponseRecorder) Size() int { return r.size } - -// Rcode returns the rcode. -func (r *ResponseRecorder) Rcode() string { return RcodeToString(r.rcode) } - -// Start returns the start time of the ResponseRecorder. -func (r *ResponseRecorder) Start() time.Time { return r.start } - -// Msg returns the written message from the ResponseRecorder. -func (r *ResponseRecorder) Msg() *dns.Msg { return r.msg } - -// Hijack implements dns.Hijacker. It simply wraps the underlying -// ResponseWriter's Hijack method if there is one, or returns an error. -func (r *ResponseRecorder) Hijack() { - r.ResponseWriter.Hijack() - return -} diff --git a/middleware/rewrite/condition.go b/middleware/rewrite/condition.go index 3ccde21f9..e9c565abc 100644 --- a/middleware/rewrite/condition.go +++ b/middleware/rewrite/condition.go @@ -5,7 +5,8 @@ import ( "regexp" "strings" - "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/middleware/pkg/replacer" + "github.com/miekg/dns" ) @@ -25,8 +26,8 @@ func operatorError(operator string) error { return fmt.Errorf("Invalid operator %v", operator) } -func newReplacer(r *dns.Msg) middleware.Replacer { - return middleware.NewReplacer(r, nil, "") +func newReplacer(r *dns.Msg) replacer.Replacer { + return replacer.New(r, nil, "") } // condition is a rewrite condition. diff --git a/middleware/rewrite/rewrite.go b/middleware/rewrite/rewrite.go index 9d8450a3e..2f2f404c9 100644 --- a/middleware/rewrite/rewrite.go +++ b/middleware/rewrite/rewrite.go @@ -117,143 +117,3 @@ func (s SimpleRule) Rewrite(r *dns.Msg) Result { } return RewriteIgnored } - -/* -// ComplexRule is a rewrite rule based on a regular expression -type ComplexRule struct { - // Path base. Request to this path and subpaths will be rewritten - Base string - - // Path to rewrite to - To string - - // If set, neither performs rewrite nor proceeds - // with request. Only returns code. - Status int - - // Extensions to filter by - Exts []string - - // Rewrite conditions - Ifs []If - - *regexp.Regexp -} - -// NewComplexRule creates a new RegexpRule. It returns an error if regexp -// pattern (pattern) or extensions (ext) are invalid. -func NewComplexRule(base, pattern, to string, status int, ext []string, ifs []If) (*ComplexRule, error) { - // validate regexp if present - var r *regexp.Regexp - if pattern != "" { - var err error - r, err = regexp.Compile(pattern) - if err != nil { - return nil, err - } - } - - // validate extensions if present - for _, v := range ext { - if len(v) < 2 || (len(v) < 3 && v[0] == '!') { - // check if no extension is specified - if v != "/" && v != "!/" { - return nil, fmt.Errorf("invalid extension %v", v) - } - } - } - - return &ComplexRule{ - Base: base, - To: to, - Status: status, - Exts: ext, - Ifs: ifs, - Regexp: r, - }, nil -} - -// Rewrite rewrites the internal location of the current request. -func (r *ComplexRule) Rewrite(req *dns.Msg) (re Result) { - rPath := req.URL.Path - replacer := newReplacer(req) - - // validate base - if !middleware.Path(rPath).Matches(r.Base) { - return - } - - // validate extensions - if !r.matchExt(rPath) { - return - } - - // validate regexp if present - if r.Regexp != nil { - // include trailing slash in regexp if present - start := len(r.Base) - if strings.HasSuffix(r.Base, "/") { - start-- - } - - matches := r.FindStringSubmatch(rPath[start:]) - switch len(matches) { - case 0: - // no match - return - default: - // set regexp match variables {1}, {2} ... - for i := 1; i < len(matches); i++ { - replacer.Set(fmt.Sprint(i), matches[i]) - } - } - } - - // validate rewrite conditions - for _, i := range r.Ifs { - if !i.True(req) { - return - } - } - - // if status is present, stop rewrite and return it. - if r.Status != 0 { - return RewriteStatus - } - - // attempt rewrite - return To(fs, req, r.To, replacer) -} - -// matchExt matches rPath against registered file extensions. -// Returns true if a match is found and false otherwise. -func (r *ComplexRule) matchExt(rPath string) bool { - f := filepath.Base(rPath) - ext := path.Ext(f) - if ext == "" { - ext = "/" - } - - mustUse := false - for _, v := range r.Exts { - use := true - if v[0] == '!' { - use = false - v = v[1:] - } - - if use { - mustUse = true - } - - if ext == v { - return use - } - } - - if mustUse { - return false - } - return true -} -*/ diff --git a/middleware/rewrite/rewrite_test.go b/middleware/rewrite/rewrite_test.go index 50aa64ad2..2dc658738 100644 --- a/middleware/rewrite/rewrite_test.go +++ b/middleware/rewrite/rewrite_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/middleware/pkg/dnsrecorder" "github.com/miekg/coredns/middleware/test" "github.com/miekg/dns" @@ -50,10 +51,10 @@ func TestRewrite(t *testing.T) { m.SetQuestion(tc.from, tc.fromT) m.Question[0].Qclass = tc.fromC - rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) + rec := dnsrecorder.New(&test.ResponseWriter{}) rw.ServeDNS(ctx, rec, m) - resp := rec.Msg() + resp := rec.Msg if resp.Question[0].Name != tc.to { t.Errorf("Test %d: Expected Name to be '%s' but was '%s'", i, tc.to, resp.Question[0].Name) } diff --git a/middleware/state_test.go b/middleware/state_test.go deleted file mode 100644 index ae4f1407a..000000000 --- a/middleware/state_test.go +++ /dev/null @@ -1,289 +0,0 @@ -package middleware - -import ( - "testing" - - "github.com/miekg/coredns/middleware/test" - - "github.com/miekg/dns" -) - -func TestStateDo(t *testing.T) { - st := testState() - - st.Do() - if st.do == 0 { - t.Fatalf("expected st.do to be set") - } -} - -func TestStateRemote(t *testing.T) { - st := testState() - if st.IP() != "10.240.0.1" { - t.Fatalf("wrong IP from state") - } - p, err := st.Port() - if err != nil { - t.Fatalf("failed to get Port from state") - } - if p != "40212" { - t.Fatalf("wrong port from state") - } -} - -func BenchmarkStateDo(b *testing.B) { - st := testState() - - for i := 0; i < b.N; i++ { - st.Do() - } -} - -func BenchmarkStateSize(b *testing.B) { - st := testState() - - for i := 0; i < b.N; i++ { - st.Size() - } -} - -func testState() State { - m := new(dns.Msg) - m.SetQuestion("example.com.", dns.TypeA) - m.SetEdns0(4097, true) - return State{W: &test.ResponseWriter{}, Req: m} -} - -/* -func TestHeader(t *testing.T) { - state := getContextOrFail(t) - - headerKey, headerVal := "Header1", "HeaderVal1" - state.Req.Header.Add(headerKey, headerVal) - - actualHeaderVal := state.Header(headerKey) - if actualHeaderVal != headerVal { - t.Errorf("Expected header %s, found %s", headerVal, actualHeaderVal) - } - - missingHeaderVal := state.Header("not-existing") - if missingHeaderVal != "" { - t.Errorf("Expected empty header value, found %s", missingHeaderVal) - } -} - -func TestIP(t *testing.T) { - state := getContextOrFail(t) - - tests := []struct { - inputRemoteAddr string - expectedIP string - }{ - // Test 0 - ipv4 with port - {"1.1.1.1:1111", "1.1.1.1"}, - // Test 1 - ipv4 without port - {"1.1.1.1", "1.1.1.1"}, - // Test 2 - ipv6 with port - {"[::1]:11", "::1"}, - // Test 3 - ipv6 without port and brackets - {"[2001:db8:a0b:12f0::1]", "[2001:db8:a0b:12f0::1]"}, - // Test 4 - ipv6 with zone and port - {`[fe80:1::3%eth0]:44`, `fe80:1::3%eth0`}, - } - - for i, test := range tests { - testPrefix := getTestPrefix(i) - - state.Req.RemoteAddr = test.inputRemoteAddr - actualIP := state.IP() - - if actualIP != test.expectedIP { - t.Errorf(testPrefix+"Expected IP %s, found %s", test.expectedIP, actualIP) - } - } -} - -func TestURL(t *testing.T) { - state := getContextOrFail(t) - - inputURL := "http://localhost" - state.Req.RequestURI = inputURL - - if inputURL != state.URI() { - t.Errorf("Expected url %s, found %s", inputURL, state.URI()) - } -} - -func TestHost(t *testing.T) { - tests := []struct { - input string - expectedHost string - shouldErr bool - }{ - { - input: "localhost:123", - expectedHost: "localhost", - shouldErr: false, - }, - { - input: "localhost", - expectedHost: "localhost", - shouldErr: false, - }, - { - input: "[::]", - expectedHost: "", - shouldErr: true, - }, - } - - for _, test := range tests { - testHostOrPort(t, true, test.input, test.expectedHost, test.shouldErr) - } -} - -func TestPort(t *testing.T) { - tests := []struct { - input string - expectedPort string - shouldErr bool - }{ - { - input: "localhost:123", - expectedPort: "123", - shouldErr: false, - }, - { - input: "localhost", - expectedPort: "80", // assuming 80 is the default port - shouldErr: false, - }, - { - input: ":8080", - expectedPort: "8080", - shouldErr: false, - }, - { - input: "[::]", - expectedPort: "", - shouldErr: true, - }, - } - - for _, test := range tests { - testHostOrPort(t, false, test.input, test.expectedPort, test.shouldErr) - } -} - -func testHostOrPort(t *testing.T, isTestingHost bool, input, expectedResult string, shouldErr bool) { - state := getContextOrFail(t) - - state.Req.Host = input - var actualResult, testedObject string - var err error - - if isTestingHost { - actualResult, err = state.Host() - testedObject = "host" - } else { - actualResult, err = state.Port() - testedObject = "port" - } - - if shouldErr && err == nil { - t.Errorf("Expected error, found nil!") - return - } - - if !shouldErr && err != nil { - t.Errorf("Expected no error, found %s", err) - return - } - - if actualResult != expectedResult { - t.Errorf("Expected %s %s, found %s", testedObject, expectedResult, actualResult) - } -} - -func TestPathMatches(t *testing.T) { - state := getContextOrFail(t) - - tests := []struct { - urlStr string - pattern string - shouldMatch bool - }{ - // Test 0 - { - urlStr: "http://localhost/", - pattern: "", - shouldMatch: true, - }, - // Test 1 - { - urlStr: "http://localhost", - pattern: "", - shouldMatch: true, - }, - // Test 1 - { - urlStr: "http://localhost/", - pattern: "/", - shouldMatch: true, - }, - // Test 3 - { - urlStr: "http://localhost/?param=val", - pattern: "/", - shouldMatch: true, - }, - // Test 4 - { - urlStr: "http://localhost/dir1/dir2", - pattern: "/dir2", - shouldMatch: false, - }, - // Test 5 - { - urlStr: "http://localhost/dir1/dir2", - pattern: "/dir1", - shouldMatch: true, - }, - // Test 6 - { - urlStr: "http://localhost:444/dir1/dir2", - pattern: "/dir1", - shouldMatch: true, - }, - } - - for i, test := range tests { - testPrefix := getTestPrefix(i) - var err error - state.Req.URL, err = url.Parse(test.urlStr) - if err != nil { - t.Fatalf("Failed to prepare test URL from string %s! Error was: %s", test.urlStr, err) - } - - matches := state.PathMatches(test.pattern) - if matches != test.shouldMatch { - t.Errorf(testPrefix+"Expected and actual result differ: expected to match [%t], actual matches [%t]", test.shouldMatch, matches) - } - } -} - -func initTestContext() (Context, error) { - body := bytes.NewBufferString("request body") - request, err := http.NewRequest("GET", "https://localhost", body) - if err != nil { - return Context{}, err - } - - return Context{Root: http.Dir(os.TempDir()), Req: request}, nil -} - - -func getTestPrefix(testN int) string { - return fmt.Sprintf("Test [%d]: ", testN) -} -*/ diff --git a/middleware/zone.go b/middleware/zone.go deleted file mode 100644 index 0f3cd0edc..000000000 --- a/middleware/zone.go +++ /dev/null @@ -1,27 +0,0 @@ -package middleware - -import "github.com/miekg/dns" - -type Zones []string - -// Matches checks is qname is a subdomain of any of the zones in z. The match -// will return the most specific zones that matches other. The empty string -// signals a not found condition. -func (z Zones) Matches(qname string) string { - zone := "" - for _, zname := range z { - if dns.IsSubDomain(zname, qname) { - if len(zname) > len(zone) { - zone = zname - } - } - } - return zone -} - -// FullyQualify fully qualifies all zones in z. -func (z Zones) FullyQualify() { - for i, _ := range z { - z[i] = dns.Fqdn(z[i]) - } -} diff --git a/middleware/state.go b/request/request.go similarity index 58% rename from middleware/state.go rename to request/request.go index 4299641ab..49d7312af 100644 --- a/middleware/state.go +++ b/request/request.go @@ -1,17 +1,16 @@ -package middleware +package request import ( "net" "strings" - "time" + + "github.com/miekg/coredns/middleware/pkg/edns" "github.com/miekg/dns" ) -// This file contains the state functions available for use in the middlewares. - -// State contains some connection state and is useful in middleware. -type State struct { +// Request contains some connection state and is useful in middleware. +type Request struct { Req *dns.Msg W dns.ResponseWriter @@ -24,38 +23,41 @@ type State struct { name string } -// Now returns the current timestamp in the specified format. -func (s *State) Now(format string) string { return time.Now().Format(format) } - -// NowDate returns the current date/time that can be used in other time functions. -func (s *State) NowDate() time.Time { return time.Now() } +// NewWithQuestion returns a new request based on the old, but with a new question +// section in the request. +func (r *Request) NewWithQuestion(name string, typ uint16) Request { + req1 := Request{W: r.W, Req: r.Req.Copy()} + req1.Req.Question[0] = dns.Question{Name: dns.Fqdn(name), Qtype: dns.ClassINET, Qclass: typ} + return req1 +} // IP gets the (remote) IP address of the client making the request. -func (s *State) IP() string { - ip, _, err := net.SplitHostPort(s.W.RemoteAddr().String()) +func (r *Request) IP() string { + ip, _, err := net.SplitHostPort(r.W.RemoteAddr().String()) if err != nil { - return s.W.RemoteAddr().String() + return r.W.RemoteAddr().String() } return ip } // Post gets the (remote) Port of the client making the request. -func (s *State) Port() (string, error) { - _, port, err := net.SplitHostPort(s.W.RemoteAddr().String()) +func (r *Request) Port() string { + _, port, err := net.SplitHostPort(r.W.RemoteAddr().String()) if err != nil { - return "0", err + return "0" } - return port, nil + return port } // RemoteAddr returns the net.Addr of the client that sent the current request. -func (s *State) RemoteAddr() string { - return s.W.RemoteAddr().String() +func (r *Request) RemoteAddr() string { + return r.W.RemoteAddr().String() } // Proto gets the protocol used as the transport. This will be udp or tcp. -func (s *State) Proto() string { return Proto(s.W) } +func (r *Request) Proto() string { return Proto(r.W) } +// FIXME(miek): why not a method on Request // Proto gets the protocol used as the transport. This will be udp or tcp. func Proto(w dns.ResponseWriter) string { if _, ok := w.RemoteAddr().(*net.UDPAddr); ok { @@ -67,11 +69,10 @@ func Proto(w dns.ResponseWriter) string { return "udp" } -// Family returns the family of the transport. -// 1 for IPv4 and 2 for IPv6. -func (s *State) Family() int { +// Family returns the family of the transport, 1 for IPv4 and 2 for IPv6. +func (r *Request) Family() int { var a net.IP - ip := s.W.RemoteAddr() + ip := r.W.RemoteAddr() if i, ok := ip.(*net.UDPAddr); ok { a = i.IP } @@ -86,48 +87,49 @@ func (s *State) Family() int { } // Do returns if the request has the DO (DNSSEC OK) bit set. -func (s *State) Do() bool { - if s.do != 0 { - return s.do == doTrue +func (r *Request) Do() bool { + if r.do != 0 { + return r.do == doTrue } - if o := s.Req.IsEdns0(); o != nil { + if o := r.Req.IsEdns0(); o != nil { if o.Do() { - s.do = doTrue + r.do = doTrue } else { - s.do = doFalse + r.do = doFalse } return o.Do() } - s.do = doFalse + r.do = doFalse return false } -// UDPSize returns if UDP buffer size advertised in the requests OPT record. +// Size returns if UDP buffer size advertised in the requests OPT record. // Or when the request was over TCP, we return the maximum allowed size of 64K. -func (s *State) Size() int { - if s.size != 0 { - return s.size +func (r *Request) Size() int { + if r.size != 0 { + return r.size } size := 0 - if o := s.Req.IsEdns0(); o != nil { + if o := r.Req.IsEdns0(); o != nil { if o.Do() == true { - s.do = doTrue + r.do = doTrue } else { - s.do = doFalse + r.do = doFalse } size = int(o.UDPSize()) } - size = edns0Size(s.Proto(), size) - s.size = size + // TODO(miek) move edns.Size to dnsutil? + size = edns.Size(r.Proto(), size) + r.size = size return size } -// SizeAndDo adds an OPT record that the reflects the intent from state. +// SizeAndDo adds an OPT record that the reflects the intent from request. // The returned bool indicated if an record was found and normalised. -func (s *State) SizeAndDo(m *dns.Msg) bool { - o := s.Req.IsEdns0() // TODO(miek): speed this up +func (r *Request) SizeAndDo(m *dns.Msg) bool { + o := r.Req.IsEdns0() // TODO(miek): speed this up if o == nil { return false } @@ -163,8 +165,8 @@ const ( // the TC bit will be set regardless of protocol, even TCP message will get the bit, the client // should then retry with pigeons. // TODO(referral). -func (s *State) Scrub(reply *dns.Msg) (*dns.Msg, Result) { - size := s.Size() +func (r *Request) Scrub(reply *dns.Msg) (*dns.Msg, Result) { + size := r.Size() l := reply.Len() if size >= l { return reply, ScrubIgnored @@ -173,7 +175,7 @@ func (s *State) Scrub(reply *dns.Msg) (*dns.Msg, Result) { // If not delegation, drop additional section. reply.Extra = nil - s.SizeAndDo(reply) + r.SizeAndDo(reply) l = reply.Len() if size >= l { return reply, ScrubDone @@ -184,42 +186,42 @@ func (s *State) Scrub(reply *dns.Msg) (*dns.Msg, Result) { } // Type returns the type of the question as a string. -func (s *State) Type() string { return dns.Type(s.Req.Question[0].Qtype).String() } +func (r *Request) Type() string { return dns.Type(r.Req.Question[0].Qtype).String() } // QType returns the type of the question as an uint16. -func (s *State) QType() uint16 { return s.Req.Question[0].Qtype } +func (r *Request) QType() uint16 { return r.Req.Question[0].Qtype } // Name returns the name of the question in the request. Note // this name will always have a closing dot and will be lower cased. After a call Name // the value will be cached. To clear this caching call Clear. -func (s *State) Name() string { - if s.name != "" { - return s.name +func (r *Request) Name() string { + if r.name != "" { + return r.name } - s.name = strings.ToLower(dns.Name(s.Req.Question[0].Name).String()) - return s.name + r.name = strings.ToLower(dns.Name(r.Req.Question[0].Name).String()) + return r.name } // QName returns the name of the question in the request. -func (s *State) QName() string { return dns.Name(s.Req.Question[0].Name).String() } +func (r *Request) QName() string { return dns.Name(r.Req.Question[0].Name).String() } // Class returns the class of the question in the request. -func (s *State) Class() string { return dns.Class(s.Req.Question[0].Qclass).String() } +func (r *Request) Class() string { return dns.Class(r.Req.Question[0].Qclass).String() } // QClass returns the class of the question in the request. -func (s *State) QClass() uint16 { return s.Req.Question[0].Qclass } +func (r *Request) QClass() uint16 { return r.Req.Question[0].Qclass } // ErrorMessage returns an error message suitable for sending // back to the client. -func (s *State) ErrorMessage(rcode int) *dns.Msg { +func (r *Request) ErrorMessage(rcode int) *dns.Msg { m := new(dns.Msg) - m.SetRcode(s.Req, rcode) + m.SetRcode(r.Req, rcode) return m } -// Clear clears all caching from State s. -func (s *State) Clear() { - s.name = "" +// Clear clears all caching from Request s. +func (r *Request) Clear() { + r.name = "" } const ( diff --git a/request/request_test.go b/request/request_test.go new file mode 100644 index 000000000..e49e9833f --- /dev/null +++ b/request/request_test.go @@ -0,0 +1,55 @@ +package request + +import ( + "testing" + + "github.com/miekg/coredns/middleware/test" + + "github.com/miekg/dns" +) + +func TestRequestDo(t *testing.T) { + st := testRequest() + + st.Do() + if st.do == 0 { + t.Fatalf("Expected st.do to be set") + } +} + +func TestRequestRemote(t *testing.T) { + st := testRequest() + if st.IP() != "10.240.0.1" { + t.Fatalf("Wrong IP from request") + } + p := st.Port() + if p == "" { + t.Fatalf("Failed to get Port from request") + } + if p != "40212" { + t.Fatalf("Wrong port from request") + } +} + +func BenchmarkRequestDo(b *testing.B) { + st := testRequest() + + for i := 0; i < b.N; i++ { + st.Do() + } +} + +func BenchmarkRequestSize(b *testing.B) { + st := testRequest() + + for i := 0; i < b.N; i++ { + st.Size() + } +} + +func testRequest() Request { + m := new(dns.Msg) + m.SetQuestion("example.com.", dns.TypeA) + m.SetEdns0(4097, true) + return Request{W: &test.ResponseWriter{}, Req: m} +} diff --git a/test/etcd_test.go b/test/etcd_test.go index 6645e9931..6f3c0b39a 100644 --- a/test/etcd_test.go +++ b/test/etcd_test.go @@ -9,11 +9,11 @@ import ( "testing" "time" - "github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware/etcd" "github.com/miekg/coredns/middleware/etcd/msg" "github.com/miekg/coredns/middleware/proxy" "github.com/miekg/coredns/middleware/test" + "github.com/miekg/coredns/request" etcdc "github.com/coreos/etcd/client" "github.com/miekg/dns" @@ -67,7 +67,7 @@ func TestEtcdStubAndProxyLookup(t *testing.T) { } p := proxy.New([]string{udp}) // use udp port from the server - state := middleware.State{W: &test.ResponseWriter{}, Req: new(dns.Msg)} + state := request.Request{W: &test.ResponseWriter{}, Req: new(dns.Msg)} resp, err := p.Lookup(state, "example.com.", dns.TypeA) if err != nil { t.Error("Expected to receive reply, but didn't") diff --git a/test/proxy_test.go b/test/proxy_test.go index 96a2c4c0d..e946f4354 100644 --- a/test/proxy_test.go +++ b/test/proxy_test.go @@ -5,9 +5,9 @@ import ( "log" "testing" - "github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware/proxy" "github.com/miekg/coredns/middleware/test" + "github.com/miekg/coredns/request" "github.com/miekg/dns" ) @@ -46,7 +46,7 @@ func TestLookupProxy(t *testing.T) { log.SetOutput(ioutil.Discard) p := proxy.New([]string{udp}) - state := middleware.State{W: &test.ResponseWriter{}, Req: new(dns.Msg)} + state := request.Request{W: &test.ResponseWriter{}, Req: new(dns.Msg)} resp, err := p.Lookup(state, "example.org.", dns.TypeA) if err != nil { t.Fatal("Expected to receive reply, but didn't") diff --git a/test/server_test.go b/test/server_test.go index a03285bf3..903922a04 100644 --- a/test/server_test.go +++ b/test/server_test.go @@ -17,7 +17,7 @@ func TestProxyToChaosServer(t *testing.T) { t.Fatalf("could not get CoreDNS serving instance: %s", err) } - udpChaos, tcpChaos := CoreDNSServerPorts(chaos, 0) + udpChaos, _ := CoreDNSServerPorts(chaos, 0) defer chaos.Stop() corefileProxy := `.:0 { @@ -32,24 +32,23 @@ func TestProxyToChaosServer(t *testing.T) { udp, _ := CoreDNSServerPorts(proxy, 0) defer proxy.Stop() - chaosTest(t, udpChaos, "udp") - chaosTest(t, tcpChaos, "tcp") + chaosTest(t, udpChaos) - chaosTest(t, udp, "udp") + chaosTest(t, udp) // chaosTest(t, tcp, "tcp"), commented out because we use the original transport to reach the // proxy and we only forward to the udp port. } -func chaosTest(t *testing.T, server, net string) { +func chaosTest(t *testing.T, server string) { m := Msg("version.bind.", dns.TypeTXT, nil) m.Question[0].Qclass = dns.ClassCHAOS - r, err := Exchange(m, server, net) + r, err := dns.Exchange(m, server) if err != nil { t.Fatalf("Could not send message: %s", err) } if r.Rcode != dns.RcodeSuccess || len(r.Answer) == 0 { - t.Fatalf("Expected successful reply on %s, got %s", net, dns.RcodeToString[r.Rcode]) + t.Fatalf("Expected successful reply, got %s", dns.RcodeToString[r.Rcode]) } if r.Answer[0].String() != `version.bind. 0 CH TXT "CoreDNS-001"` { t.Fatalf("Expected version.bind. reply, got %s", r.Answer[0].String()) diff --git a/test/tests.go b/test/tests.go index 0f9d12bae..fb2853e9a 100644 --- a/test/tests.go +++ b/test/tests.go @@ -1,10 +1,6 @@ package test -import ( - "github.com/miekg/coredns/middleware" - - "github.com/miekg/dns" -) +import "github.com/miekg/dns" func Msg(zone string, typ uint16, o *dns.OPT) *dns.Msg { m := new(dns.Msg) @@ -14,9 +10,3 @@ func Msg(zone string, typ uint16, o *dns.OPT) *dns.Msg { } return m } - -func Exchange(m *dns.Msg, server, net string) (*dns.Msg, error) { - c := new(dns.Client) - c.Net = net - return middleware.Exchange(c, m, server) -}