Cleanup: put middleware helper functions in pkgs (#245)

Move all (almost all) Go files in middleware into their
own packages. This makes for better naming and discoverability.

Lot of changes elsewhere to make this change.

The middleware.State was renamed to request.Request which is better,
but still does not cover all use-cases. It was also moved out middleware
because it is used by `dnsserver` as well.

A pkg/dnsutil packages was added for shared, handy, dns util functions.

All normalize functions are now put in normalize.go
This commit is contained in:
Miek Gieben 2016-09-07 11:10:16 +01:00 committed by GitHub
parent 684330fd28
commit d1f17fa7e0
90 changed files with 680 additions and 1037 deletions

View file

@ -7,7 +7,8 @@ import (
"sync" "sync"
"time" "time"
"github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware/pkg/edns"
"github.com/miekg/coredns/request"
"github.com/miekg/dns" "github.com/miekg/dns"
"golang.org/x/net/context" "golang.org/x/net/context"
@ -163,7 +164,7 @@ func (s *Server) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
} }
}() }()
if m, err := middleware.Edns0Version(r); err != nil { // Wrong EDNS version, return at once. if m, err := edns.Version(r); err != nil { // Wrong EDNS version, return at once.
w.WriteMsg(m) w.WriteMsg(m)
return return
} }
@ -214,10 +215,11 @@ func (s *Server) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
// DefaultErrorFunc responds to an DNS request with an error. // DefaultErrorFunc responds to an DNS request with an error.
func DefaultErrorFunc(w dns.ResponseWriter, r *dns.Msg, rcode int) { func DefaultErrorFunc(w dns.ResponseWriter, r *dns.Msg, rcode int) {
state := middleware.State{W: w, Req: r} state := request.Request{W: w, Req: r}
answer := new(dns.Msg) answer := new(dns.Msg)
answer.SetRcode(r, rcode) answer.SetRcode(r, rcode)
state.SizeAndDo(answer) state.SizeAndDo(answer)
w.WriteMsg(answer) w.WriteMsg(answer)

View file

@ -6,6 +6,7 @@ import (
"time" "time"
"github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/pkg/response"
"github.com/miekg/dns" "github.com/miekg/dns"
gcache "github.com/patrickmn/go-cache" gcache "github.com/patrickmn/go-cache"
@ -23,7 +24,7 @@ func NewCache(ttl int, zones []string, next middleware.Handler) Cache {
return Cache{Next: next, Zones: zones, cache: gcache.New(defaultDuration, purgeDuration), cap: time.Duration(ttl) * time.Second} return Cache{Next: next, Zones: zones, cache: gcache.New(defaultDuration, purgeDuration), cap: time.Duration(ttl) * time.Second}
} }
func cacheKey(m *dns.Msg, t middleware.MsgType, do bool) string { func cacheKey(m *dns.Msg, t response.Type, do bool) string {
if m.Truncated { if m.Truncated {
return "" return ""
} }
@ -31,15 +32,15 @@ func cacheKey(m *dns.Msg, t middleware.MsgType, do bool) string {
qtype := m.Question[0].Qtype qtype := m.Question[0].Qtype
qname := strings.ToLower(m.Question[0].Name) qname := strings.ToLower(m.Question[0].Name)
switch t { switch t {
case middleware.Success: case response.Success:
fallthrough fallthrough
case middleware.Delegation: case response.Delegation:
return successKey(qname, qtype, do) return successKey(qname, qtype, do)
case middleware.NameError: case response.NameError:
return nameErrorKey(qname, do) return nameErrorKey(qname, do)
case middleware.NoData: case response.NoData:
return noDataKey(qname, qtype, do) return noDataKey(qname, qtype, do)
case middleware.OtherError: case response.OtherError:
return "" return ""
} }
return "" return ""
@ -57,7 +58,7 @@ func NewCachingResponseWriter(w dns.ResponseWriter, cache *gcache.Cache, cap tim
func (c *CachingResponseWriter) WriteMsg(res *dns.Msg) error { func (c *CachingResponseWriter) WriteMsg(res *dns.Msg) error {
do := false do := false
mt, opt := middleware.Classify(res) mt, opt := response.Classify(res)
if opt != nil { if opt != nil {
do = opt.Do() do = opt.Do()
} }
@ -72,7 +73,7 @@ func (c *CachingResponseWriter) WriteMsg(res *dns.Msg) error {
return c.ResponseWriter.WriteMsg(res) return c.ResponseWriter.WriteMsg(res)
} }
func (c *CachingResponseWriter) set(m *dns.Msg, key string, mt middleware.MsgType) { func (c *CachingResponseWriter) set(m *dns.Msg, key string, mt response.Type) {
if key == "" { if key == "" {
log.Printf("[ERROR] Caching called with empty cache key") log.Printf("[ERROR] Caching called with empty cache key")
return return
@ -80,21 +81,21 @@ func (c *CachingResponseWriter) set(m *dns.Msg, key string, mt middleware.MsgTyp
duration := c.cap duration := c.cap
switch mt { switch mt {
case middleware.Success, middleware.Delegation: case response.Success, response.Delegation:
if c.cap == 0 { if c.cap == 0 {
duration = minTtl(m.Answer, mt) duration = minTtl(m.Answer, mt)
} }
i := newItem(m, duration) i := newItem(m, duration)
c.cache.Set(key, i, duration) c.cache.Set(key, i, duration)
case middleware.NameError, middleware.NoData: case response.NameError, response.NoData:
if c.cap == 0 { if c.cap == 0 {
duration = minTtl(m.Ns, mt) duration = minTtl(m.Ns, mt)
} }
i := newItem(m, duration) i := newItem(m, duration)
c.cache.Set(key, i, duration) c.cache.Set(key, i, duration)
case middleware.OtherError: case response.OtherError:
// don't cache these // don't cache these
default: default:
log.Printf("[WARNING] Caching called with unknown middleware MsgType: %d", mt) log.Printf("[WARNING] Caching called with unknown middleware MsgType: %d", mt)
@ -112,19 +113,19 @@ func (c *CachingResponseWriter) Hijack() {
return return
} }
func minTtl(rrs []dns.RR, mt middleware.MsgType) time.Duration { func minTtl(rrs []dns.RR, mt response.Type) time.Duration {
if mt != middleware.Success && mt != middleware.NameError && mt != middleware.NoData { if mt != response.Success && mt != response.NameError && mt != response.NoData {
return 0 return 0
} }
minTtl := maxTtl minTtl := maxTtl
for _, r := range rrs { for _, r := range rrs {
switch mt { switch mt {
case middleware.NameError, middleware.NoData: case response.NameError, response.NoData:
if r.Header().Rrtype == dns.TypeSOA { if r.Header().Rrtype == dns.TypeSOA {
return time.Duration(r.(*dns.SOA).Minttl) * time.Second return time.Duration(r.(*dns.SOA).Minttl) * time.Second
} }
case middleware.Success, middleware.Delegation: case response.Success, response.Delegation:
if r.Header().Ttl < minTtl { if r.Header().Ttl < minTtl {
minTtl = r.Header().Ttl minTtl = r.Header().Ttl
} }

View file

@ -5,6 +5,7 @@ import (
"time" "time"
"github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/pkg/response"
"github.com/miekg/coredns/middleware/test" "github.com/miekg/coredns/middleware/test"
"github.com/miekg/dns" "github.com/miekg/dns"
@ -78,7 +79,7 @@ func TestCache(t *testing.T) {
m = cacheMsg(m, tc) m = cacheMsg(m, tc)
do := tc.in.Do do := tc.in.Do
mt, _ := middleware.Classify(m) mt, _ := response.Classify(m)
key := cacheKey(m, mt, do) key := cacheKey(m, mt, do)
crr.set(m, key, mt) crr.set(m, key, mt)

View file

@ -2,6 +2,7 @@ package cache
import ( import (
"github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/request"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
@ -10,7 +11,7 @@ import (
// ServeDNS implements the middleware.Handler interface. // ServeDNS implements the middleware.Handler interface.
func (c Cache) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { func (c Cache) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
state := middleware.State{W: w, Req: r} state := request.Request{W: w, Req: r}
qname := state.Name() qname := state.Name()
qtype := state.QType() qtype := state.QType()

View file

@ -4,6 +4,7 @@ import (
"os" "os"
"github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/request"
"github.com/miekg/dns" "github.com/miekg/dns"
"golang.org/x/net/context" "golang.org/x/net/context"
@ -18,7 +19,7 @@ type Chaos struct {
} }
func (c Chaos) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { func (c Chaos) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
state := middleware.State{W: w, Req: r} state := request.Request{W: w, Req: r}
if state.QClass() != dns.ClassCHAOS || state.QType() != dns.TypeTXT { if state.QClass() != dns.ClassCHAOS || state.QType() != dns.TypeTXT {
return c.Next.ServeDNS(ctx, w, r) return c.Next.ServeDNS(ctx, w, r)
} }

View file

@ -4,6 +4,7 @@ import (
"testing" "testing"
"github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/pkg/dnsrecorder"
"github.com/miekg/coredns/middleware/test" "github.com/miekg/coredns/middleware/test"
"github.com/miekg/dns" "github.com/miekg/dns"
@ -58,7 +59,7 @@ func TestChaos(t *testing.T) {
req.Question[0].Qclass = dns.ClassCHAOS req.Question[0].Qclass = dns.ClassCHAOS
em.Next = tc.next em.Next = tc.next
rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) rec := dnsrecorder.New(&test.ResponseWriter{})
code, err := em.ServeDNS(ctx, rec, req) code, err := em.ServeDNS(ctx, rec, req)
if err != tc.expectedErr { if err != tc.expectedErr {
@ -68,7 +69,7 @@ func TestChaos(t *testing.T) {
t.Errorf("Test %d: Expected status code %d, but got %d", i, tc.expectedCode, code) t.Errorf("Test %d: Expected status code %d, but got %d", i, tc.expectedCode, code)
} }
if tc.expectedReply != "" { if tc.expectedReply != "" {
answer := rec.Msg().Answer[0].(*dns.TXT).Txt[0] answer := rec.Msg.Answer[0].(*dns.TXT).Txt[0]
if answer != tc.expectedReply { if answer != tc.expectedReply {
t.Errorf("Test %d: Expected answer %s, but got %s", i, tc.expectedReply, answer) t.Errorf("Test %d: Expected answer %s, but got %s", i, tc.expectedReply, answer)
} }

View file

@ -4,8 +4,8 @@ import (
"testing" "testing"
"time" "time"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/test" "github.com/miekg/coredns/middleware/test"
"github.com/miekg/coredns/request"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
@ -16,7 +16,7 @@ func TestZoneSigningBlackLies(t *testing.T) {
defer rm2() defer rm2()
m := testNxdomainMsg() m := testNxdomainMsg()
state := middleware.State{Req: m} state := request.Request{Req: m}
m = d.Sign(state, "miek.nl.", time.Now().UTC()) m = d.Sign(state, "miek.nl.", time.Now().UTC())
if !section(m.Ns, 2) { if !section(m.Ns, 2) {
t.Errorf("authority section should have 2 sig") t.Errorf("authority section should have 2 sig")

View file

@ -4,8 +4,8 @@ import (
"testing" "testing"
"time" "time"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/test" "github.com/miekg/coredns/middleware/test"
"github.com/miekg/coredns/request"
) )
func TestCacheSet(t *testing.T) { func TestCacheSet(t *testing.T) {
@ -20,7 +20,7 @@ func TestCacheSet(t *testing.T) {
} }
m := testMsg() m := testMsg()
state := middleware.State{Req: m} state := request.Request{Req: m}
k := key(m.Answer) // calculate *before* we add the sig k := key(m.Answer) // calculate *before* we add the sig
d := New([]string{"miek.nl."}, []*DNSKEY{dnskey}, nil) d := New([]string{"miek.nl."}, []*DNSKEY{dnskey}, nil)
m = d.Sign(state, "miek.nl.", time.Now().UTC()) m = d.Sign(state, "miek.nl.", time.Now().UTC())

View file

@ -8,7 +8,7 @@ import (
"os" "os"
"time" "time"
"github.com/miekg/coredns/middleware" "github.com/miekg/coredns/request"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
@ -50,7 +50,7 @@ func ParseKeyFile(pubFile, privFile string) (*DNSKEY, error) {
} }
// getDNSKEY returns the correct DNSKEY to the client. Signatures are added when do is true. // getDNSKEY returns the correct DNSKEY to the client. Signatures are added when do is true.
func (d Dnssec) getDNSKEY(state middleware.State, zone string, do bool) *dns.Msg { func (d Dnssec) getDNSKEY(state request.Request, zone string, do bool) *dns.Msg {
keys := make([]dns.RR, len(d.keys)) keys := make([]dns.RR, len(d.keys))
for i, k := range d.keys { for i, k := range d.keys {
keys[i] = dns.Copy(k.K) keys[i] = dns.Copy(k.K)

View file

@ -4,7 +4,9 @@ import (
"time" "time"
"github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/singleflight" "github.com/miekg/coredns/middleware/pkg/response"
"github.com/miekg/coredns/middleware/pkg/singleflight"
"github.com/miekg/coredns/request"
"github.com/miekg/dns" "github.com/miekg/dns"
gcache "github.com/patrickmn/go-cache" gcache "github.com/patrickmn/go-cache"
@ -28,20 +30,21 @@ func New(zones []string, keys []*DNSKEY, next middleware.Handler) Dnssec {
} }
} }
// Sign signs the message m. it takes care of negative or nodata responses. It // Sign signs the message in state. it takes care of negative or nodata responses. It
// uses NSEC black lies for authenticated denial of existence. Signatures // uses NSEC black lies for authenticated denial of existence. Signatures
// creates will be cached for a short while. By default we sign for 8 days, // creates will be cached for a short while. By default we sign for 8 days,
// starting 3 hours ago. // starting 3 hours ago.
func (d Dnssec) Sign(state middleware.State, zone string, now time.Time) *dns.Msg { func (d Dnssec) Sign(state request.Request, zone string, now time.Time) *dns.Msg {
req := state.Req req := state.Req
mt, _ := middleware.Classify(req) // TODO(miek): need opt record here?
if mt == middleware.Delegation { mt, _ := response.Classify(req) // TODO(miek): need opt record here?
if mt == response.Delegation {
return req return req
} }
incep, expir := incepExpir(now) incep, expir := incepExpir(now)
if mt == middleware.NameError { if mt == response.NameError {
if req.Ns[0].Header().Rrtype != dns.TypeSOA || len(req.Ns) > 1 { if req.Ns[0].Header().Rrtype != dns.TypeSOA || len(req.Ns) > 1 {
return req return req
} }

View file

@ -4,8 +4,8 @@ import (
"testing" "testing"
"time" "time"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/test" "github.com/miekg/coredns/middleware/test"
"github.com/miekg/coredns/request"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
@ -16,7 +16,7 @@ func TestZoneSigning(t *testing.T) {
defer rm2() defer rm2()
m := testMsg() m := testMsg()
state := middleware.State{Req: m} state := request.Request{Req: m}
m = d.Sign(state, "miek.nl.", time.Now().UTC()) m = d.Sign(state, "miek.nl.", time.Now().UTC())
if !section(m.Answer, 1) { if !section(m.Answer, 1) {
@ -44,7 +44,7 @@ func TestZoneSigningDouble(t *testing.T) {
d.keys = append(d.keys, key1) d.keys = append(d.keys, key1)
m := testMsg() m := testMsg()
state := middleware.State{Req: m} state := request.Request{Req: m}
m = d.Sign(state, "miek.nl.", time.Now().UTC()) m = d.Sign(state, "miek.nl.", time.Now().UTC())
if !section(m.Answer, 2) { if !section(m.Answer, 2) {
t.Errorf("answer section should have 1 sig") t.Errorf("answer section should have 1 sig")
@ -68,7 +68,7 @@ func TestSigningDifferentZone(t *testing.T) {
} }
m := testMsgEx() m := testMsgEx()
state := middleware.State{Req: m} state := request.Request{Req: m}
d := New([]string{"example.org."}, []*DNSKEY{key}, nil) d := New([]string{"example.org."}, []*DNSKEY{key}, nil)
m = d.Sign(state, "example.org.", time.Now().UTC()) m = d.Sign(state, "example.org.", time.Now().UTC())
if !section(m.Answer, 1) { if !section(m.Answer, 1) {
@ -86,7 +86,7 @@ func TestSigningCname(t *testing.T) {
defer rm2() defer rm2()
m := testMsgCname() m := testMsgCname()
state := middleware.State{Req: m} state := request.Request{Req: m}
m = d.Sign(state, "miek.nl.", time.Now().UTC()) m = d.Sign(state, "miek.nl.", time.Now().UTC())
if !section(m.Answer, 1) { if !section(m.Answer, 1) {
t.Errorf("answer section should have 1 sig") t.Errorf("answer section should have 1 sig")
@ -100,7 +100,7 @@ func TestZoneSigningDelegation(t *testing.T) {
defer rm2() defer rm2()
m := testDelegationMsg() m := testDelegationMsg()
state := middleware.State{Req: m} state := request.Request{Req: m}
m = d.Sign(state, "miek.nl.", time.Now().UTC()) m = d.Sign(state, "miek.nl.", time.Now().UTC())
if !section(m.Ns, 0) { if !section(m.Ns, 0) {
t.Errorf("authority section should have 0 sig") t.Errorf("authority section should have 0 sig")

View file

@ -2,6 +2,7 @@ package dnssec
import ( import (
"github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/request"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
@ -10,7 +11,7 @@ import (
// ServeDNS implements the middleware.Handler interface. // ServeDNS implements the middleware.Handler interface.
func (d Dnssec) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { func (d Dnssec) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
state := middleware.State{W: w, Req: r} state := request.Request{W: w, Req: r}
do := state.Do() do := state.Do()
qname := state.Name() qname := state.Name()

View file

@ -5,8 +5,8 @@ import (
"strings" "strings"
"testing" "testing"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/file" "github.com/miekg/coredns/middleware/file"
"github.com/miekg/coredns/middleware/pkg/dnsrecorder"
"github.com/miekg/coredns/middleware/test" "github.com/miekg/coredns/middleware/test"
"github.com/miekg/dns" "github.com/miekg/dns"
@ -83,14 +83,14 @@ func TestLookupZone(t *testing.T) {
for _, tc := range dnsTestCases { for _, tc := range dnsTestCases {
m := tc.Msg() m := tc.Msg()
rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) rec := dnsrecorder.New(&test.ResponseWriter{})
_, err := dh.ServeDNS(ctx, rec, m) _, err := dh.ServeDNS(ctx, rec, m)
if err != nil { if err != nil {
t.Errorf("expected no error, got %v\n", err) t.Errorf("expected no error, got %v\n", err)
return return
} }
resp := rec.Msg()
resp := rec.Msg
sort.Sort(test.RRSet(resp.Answer)) sort.Sort(test.RRSet(resp.Answer))
sort.Sort(test.RRSet(resp.Ns)) sort.Sort(test.RRSet(resp.Ns))
sort.Sort(test.RRSet(resp.Extra)) sort.Sort(test.RRSet(resp.Extra))
@ -121,14 +121,14 @@ func TestLookupDNSKEY(t *testing.T) {
for _, tc := range dnssecTestCases { for _, tc := range dnssecTestCases {
m := tc.Msg() m := tc.Msg()
rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) rec := dnsrecorder.New(&test.ResponseWriter{})
_, err := dh.ServeDNS(ctx, rec, m) _, err := dh.ServeDNS(ctx, rec, m)
if err != nil { if err != nil {
t.Errorf("expected no error, got %v\n", err) t.Errorf("expected no error, got %v\n", err)
return return
} }
resp := rec.Msg()
resp := rec.Msg
if !resp.Authoritative { if !resp.Authoritative {
t.Errorf("Authoritative Answer should be true, got false") t.Errorf("Authoritative Answer should be true, got false")
} }

View file

@ -5,6 +5,8 @@ import (
"time" "time"
"github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/request"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
@ -20,7 +22,7 @@ func NewDnssecResponseWriter(w dns.ResponseWriter, d Dnssec) *DnssecResponseWrit
func (d *DnssecResponseWriter) WriteMsg(res *dns.Msg) error { func (d *DnssecResponseWriter) WriteMsg(res *dns.Msg) error {
// By definition we should sign anything that comes back, we should still figure out for // By definition we should sign anything that comes back, we should still figure out for
// which zone it should be. // which zone it should be.
state := middleware.State{W: d.ResponseWriter, Req: res} state := request.Request{W: d.ResponseWriter, Req: res}
qname := state.Name() qname := state.Name()
zone := middleware.Zones(d.d.zones).Matches(qname) zone := middleware.Zones(d.d.zones).Matches(qname)

View file

@ -9,6 +9,8 @@ import (
"time" "time"
"github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/pkg/roller"
"github.com/miekg/coredns/request"
"github.com/miekg/dns" "github.com/miekg/dns"
"golang.org/x/net/context" "golang.org/x/net/context"
@ -19,7 +21,7 @@ type ErrorHandler struct {
Next middleware.Handler Next middleware.Handler
LogFile string LogFile string
Log *log.Logger Log *log.Logger
LogRoller *middleware.LogRoller LogRoller *roller.LogRoller
Debug bool // if true, errors are written out to client rather than to a log Debug bool // if true, errors are written out to client rather than to a log
} }
@ -29,7 +31,7 @@ func (h ErrorHandler) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns
rcode, err := h.Next.ServeDNS(ctx, w, r) rcode, err := h.Next.ServeDNS(ctx, w, r)
if err != nil { if err != nil {
state := middleware.State{W: w, Req: r} state := request.Request{W: w, Req: r}
errMsg := fmt.Sprintf("%s [ERROR %d %s %s] %v", time.Now().Format(timeFormat), rcode, state.Name(), state.Type(), err) errMsg := fmt.Sprintf("%s [ERROR %d %s %s] %v", time.Now().Format(timeFormat), rcode, state.Name(), state.Type(), err)
if h.Debug { if h.Debug {
@ -53,7 +55,7 @@ func (h ErrorHandler) recovery(ctx context.Context, w dns.ResponseWriter, r *dns
return return
} }
state := middleware.State{W: w, Req: r} state := request.Request{W: w, Req: r}
// Obtain source of panic // Obtain source of panic
// From: https://gist.github.com/swdunlop/9629168 // From: https://gist.github.com/swdunlop/9629168
var name, file string // function name, file name var name, file string // function name, file name

View file

@ -9,6 +9,7 @@ import (
"testing" "testing"
"github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/pkg/dnsrecorder"
"github.com/miekg/coredns/middleware/test" "github.com/miekg/coredns/middleware/test"
"github.com/miekg/dns" "github.com/miekg/dns"
@ -47,7 +48,7 @@ func TestErrors(t *testing.T) {
for i, tc := range tests { for i, tc := range tests {
em.Next = tc.next em.Next = tc.next
buf.Reset() buf.Reset()
rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) rec := dnsrecorder.New(&test.ResponseWriter{})
code, err := em.ServeDNS(ctx, rec, req) code, err := em.ServeDNS(ctx, rec, req)
if err != tc.expectedErr { if err != tc.expectedErr {
@ -78,7 +79,7 @@ func TestVisibleErrorWithPanic(t *testing.T) {
req := new(dns.Msg) req := new(dns.Msg)
req.SetQuestion("example.org.", dns.TypeA) req.SetQuestion("example.org.", dns.TypeA)
rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) rec := dnsrecorder.New(&test.ResponseWriter{})
code, err := eh.ServeDNS(ctx, rec, req) code, err := eh.ServeDNS(ctx, rec, req)
if code != 0 { if code != 0 {

View file

@ -6,7 +6,7 @@ import (
"os" "os"
"github.com/miekg/coredns/core/dnsserver" "github.com/miekg/coredns/core/dnsserver"
"github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware/pkg/roller"
"github.com/hashicorp/go-syslog" "github.com/hashicorp/go-syslog"
"github.com/mholt/caddy" "github.com/mholt/caddy"
@ -93,7 +93,7 @@ func errorsParse(c *caddy.Controller) (ErrorHandler, error) {
if c.NextArg() { if c.NextArg() {
if c.Val() == "{" { if c.Val() == "{" {
c.IncrNest() c.IncrNest()
logRoller, err := middleware.ParseRoller(c) logRoller, err := roller.Parse(c)
if err != nil { if err != nil {
return hadBlock, err return hadBlock, err
} }

View file

@ -4,7 +4,7 @@ import (
"testing" "testing"
"github.com/mholt/caddy" "github.com/mholt/caddy"
"github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware/pkg/roller"
) )
func TestErrorsParse(t *testing.T) { func TestErrorsParse(t *testing.T) {
@ -29,7 +29,7 @@ func TestErrorsParse(t *testing.T) {
}}, }},
{`errors { log errors.txt { size 2 age 10 keep 3 } }`, false, ErrorHandler{ {`errors { log errors.txt { size 2 age 10 keep 3 } }`, false, ErrorHandler{
LogFile: "errors.txt", LogFile: "errors.txt",
LogRoller: &middleware.LogRoller{ LogRoller: &roller.LogRoller{
MaxSize: 2, MaxSize: 2,
MaxAge: 10, MaxAge: 10,
MaxBackups: 3, MaxBackups: 3,
@ -43,7 +43,7 @@ func TestErrorsParse(t *testing.T) {
} }
}`, false, ErrorHandler{ }`, false, ErrorHandler{
LogFile: "errors.txt", LogFile: "errors.txt",
LogRoller: &middleware.LogRoller{ LogRoller: &roller.LogRoller{
MaxSize: 3, MaxSize: 3,
MaxAge: 11, MaxAge: 11,
MaxBackups: 5, MaxBackups: 5,

View file

@ -7,8 +7,8 @@ package etcd
import ( import (
"testing" "testing"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/etcd/msg" "github.com/miekg/coredns/middleware/etcd/msg"
"github.com/miekg/coredns/middleware/pkg/dnsrecorder"
"github.com/miekg/coredns/middleware/test" "github.com/miekg/coredns/middleware/test"
"github.com/miekg/dns" "github.com/miekg/dns"
@ -25,14 +25,14 @@ func TestCnameLookup(t *testing.T) {
for _, tc := range dnsTestCasesCname { for _, tc := range dnsTestCasesCname {
m := tc.Msg() m := tc.Msg()
rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) rec := dnsrecorder.New(&test.ResponseWriter{})
_, err := etc.ServeDNS(ctxt, rec, m) _, err := etc.ServeDNS(ctxt, rec, m)
if err != nil { if err != nil {
t.Errorf("expected no error, got %v\n", err) t.Errorf("expected no error, got %v\n", err)
return return
} }
resp := rec.Msg()
resp := rec.Msg
if !test.Header(t, tc, resp) { if !test.Header(t, tc, resp) {
t.Logf("%v\n", resp) t.Logf("%v\n", resp)
continue continue

View file

@ -7,8 +7,8 @@ import (
"strings" "strings"
"testing" "testing"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/etcd/msg" "github.com/miekg/coredns/middleware/etcd/msg"
"github.com/miekg/coredns/middleware/pkg/dnsrecorder"
"github.com/miekg/coredns/middleware/test" "github.com/miekg/coredns/middleware/test"
"github.com/miekg/dns" "github.com/miekg/dns"
@ -41,14 +41,14 @@ func TestDebugLookup(t *testing.T) {
for _, tc := range dnsTestCasesDebug { for _, tc := range dnsTestCasesDebug {
m := tc.Msg() m := tc.Msg()
rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) rec := dnsrecorder.New(&test.ResponseWriter{})
_, err := etc.ServeDNS(ctxt, rec, m) _, err := etc.ServeDNS(ctxt, rec, m)
if err != nil { if err != nil {
t.Errorf("expected no error, got %v\n", err) t.Errorf("expected no error, got %v\n", err)
continue continue
} }
resp := rec.Msg()
resp := rec.Msg
sort.Sort(test.RRSet(resp.Answer)) sort.Sort(test.RRSet(resp.Answer))
sort.Sort(test.RRSet(resp.Ns)) sort.Sort(test.RRSet(resp.Ns))
sort.Sort(test.RRSet(resp.Extra)) sort.Sort(test.RRSet(resp.Extra))
@ -79,14 +79,14 @@ func TestDebugLookupFalse(t *testing.T) {
for _, tc := range dnsTestCasesDebugFalse { for _, tc := range dnsTestCasesDebugFalse {
m := tc.Msg() m := tc.Msg()
rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) rec := dnsrecorder.New(&test.ResponseWriter{})
_, err := etc.ServeDNS(ctxt, rec, m) _, err := etc.ServeDNS(ctxt, rec, m)
if err != nil { if err != nil {
t.Errorf("expected no error, got %v\n", err) t.Errorf("expected no error, got %v\n", err)
continue continue
} }
resp := rec.Msg()
resp := rec.Msg
sort.Sort(test.RRSet(resp.Answer)) sort.Sort(test.RRSet(resp.Answer))
sort.Sort(test.RRSet(resp.Ns)) sort.Sort(test.RRSet(resp.Ns))
sort.Sort(test.RRSet(resp.Extra)) sort.Sort(test.RRSet(resp.Extra))

View file

@ -9,8 +9,8 @@ import (
"github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/etcd/msg" "github.com/miekg/coredns/middleware/etcd/msg"
"github.com/miekg/coredns/middleware/pkg/singleflight"
"github.com/miekg/coredns/middleware/proxy" "github.com/miekg/coredns/middleware/proxy"
"github.com/miekg/coredns/singleflight"
etcdc "github.com/coreos/etcd/client" etcdc "github.com/coreos/etcd/client"
"golang.org/x/net/context" "golang.org/x/net/context"

View file

@ -6,8 +6,8 @@ import (
"sort" "sort"
"testing" "testing"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/etcd/msg" "github.com/miekg/coredns/middleware/etcd/msg"
"github.com/miekg/coredns/middleware/pkg/dnsrecorder"
"github.com/miekg/coredns/middleware/test" "github.com/miekg/coredns/middleware/test"
"github.com/miekg/dns" "github.com/miekg/dns"
@ -23,14 +23,14 @@ func TestGroupLookup(t *testing.T) {
for _, tc := range dnsTestCasesGroup { for _, tc := range dnsTestCasesGroup {
m := tc.Msg() m := tc.Msg()
rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) rec := dnsrecorder.New(&test.ResponseWriter{})
_, err := etc.ServeDNS(ctxt, rec, m) _, err := etc.ServeDNS(ctxt, rec, m)
if err != nil { if err != nil {
t.Errorf("expected no error, got %v\n", err) t.Errorf("expected no error, got %v\n", err)
continue continue
} }
resp := rec.Msg()
resp := rec.Msg
sort.Sort(test.RRSet(resp.Answer)) sort.Sort(test.RRSet(resp.Answer))
sort.Sort(test.RRSet(resp.Ns)) sort.Sort(test.RRSet(resp.Ns))
sort.Sort(test.RRSet(resp.Extra)) sort.Sort(test.RRSet(resp.Extra))

View file

@ -5,6 +5,7 @@ import (
"github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/etcd/msg" "github.com/miekg/coredns/middleware/etcd/msg"
"github.com/miekg/coredns/request"
"github.com/miekg/dns" "github.com/miekg/dns"
"golang.org/x/net/context" "golang.org/x/net/context"
@ -12,7 +13,7 @@ import (
func (e *Etcd) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { func (e *Etcd) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
opt := Options{} opt := Options{}
state := middleware.State{W: w, Req: r} state := request.Request{W: w, Req: r}
if state.QClass() != dns.ClassINET { if state.QClass() != dns.ClassINET {
return dns.RcodeServerFailure, fmt.Errorf("can only deal with ClassINET") return dns.RcodeServerFailure, fmt.Errorf("can only deal with ClassINET")
} }
@ -115,7 +116,7 @@ func (e *Etcd) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (
} }
// Err write an error response to the client. // Err write an error response to the client.
func (e *Etcd) Err(zone string, rcode int, state middleware.State, debug []msg.Service, err error, opt Options) (int, error) { func (e *Etcd) Err(zone string, rcode int, state request.Request, debug []msg.Service, err error, opt Options) (int, error) {
m := new(dns.Msg) m := new(dns.Msg)
m.SetRcode(state.Req, rcode) m.SetRcode(state.Req, rcode)
m.Authoritative, m.RecursionAvailable, m.Compress = true, true, true m.Authoritative, m.RecursionAvailable, m.Compress = true, true, true

View file

@ -8,6 +8,8 @@ import (
"github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/etcd/msg" "github.com/miekg/coredns/middleware/etcd/msg"
"github.com/miekg/coredns/middleware/pkg/dnsutil"
"github.com/miekg/coredns/request"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
@ -16,7 +18,7 @@ type Options struct {
Debug string Debug string
} }
func (e Etcd) records(state middleware.State, exact bool, opt Options) (services, debug []msg.Service, err error) { func (e Etcd) records(state request.Request, exact bool, opt Options) (services, debug []msg.Service, err error) {
services, err = e.Records(state.Name(), exact) services, err = e.Records(state.Name(), exact)
if err != nil { if err != nil {
return return
@ -28,7 +30,7 @@ func (e Etcd) records(state middleware.State, exact bool, opt Options) (services
return return
} }
func (e Etcd) A(zone string, state middleware.State, previousRecords []dns.RR, opt Options) (records []dns.RR, debug []msg.Service, err error) { func (e Etcd) A(zone string, state request.Request, previousRecords []dns.RR, opt Options) (records []dns.RR, debug []msg.Service, err error) {
services, debug, err := e.records(state, false, opt) services, debug, err := e.records(state, false, opt)
if err != nil { if err != nil {
return nil, debug, err return nil, debug, err
@ -49,11 +51,11 @@ func (e Etcd) A(zone string, state middleware.State, previousRecords []dns.RR, o
// don't add it, and just continue // don't add it, and just continue
continue continue
} }
if isDuplicateCNAME(newRecord, previousRecords) { if dnsutil.DuplicateCNAME(newRecord, previousRecords) {
continue continue
} }
state1 := copyState(state, serv.Host, state.QType()) state1 := state.NewWithQuestion(serv.Host, state.QType())
nextRecords, nextDebug, err := e.A(zone, state1, append(previousRecords, newRecord), opt) nextRecords, nextDebug, err := e.A(zone, state1, append(previousRecords, newRecord), opt)
if err == nil { if err == nil {
@ -90,7 +92,7 @@ func (e Etcd) A(zone string, state middleware.State, previousRecords []dns.RR, o
return records, debug, nil return records, debug, nil
} }
func (e Etcd) AAAA(zone string, state middleware.State, previousRecords []dns.RR, opt Options) (records []dns.RR, debug []msg.Service, err error) { func (e Etcd) AAAA(zone string, state request.Request, previousRecords []dns.RR, opt Options) (records []dns.RR, debug []msg.Service, err error) {
services, debug, err := e.records(state, false, opt) services, debug, err := e.records(state, false, opt)
if err != nil { if err != nil {
return nil, debug, err return nil, debug, err
@ -111,11 +113,11 @@ func (e Etcd) AAAA(zone string, state middleware.State, previousRecords []dns.RR
// don't add it, and just continue // don't add it, and just continue
continue continue
} }
if isDuplicateCNAME(newRecord, previousRecords) { if dnsutil.DuplicateCNAME(newRecord, previousRecords) {
continue continue
} }
state1 := copyState(state, serv.Host, state.QType()) state1 := state.NewWithQuestion(serv.Host, state.QType())
nextRecords, nextDebug, err := e.AAAA(zone, state1, append(previousRecords, newRecord), opt) nextRecords, nextDebug, err := e.AAAA(zone, state1, append(previousRecords, newRecord), opt)
if err == nil { if err == nil {
@ -155,7 +157,7 @@ func (e Etcd) AAAA(zone string, state middleware.State, previousRecords []dns.RR
// SRV returns SRV records from etcd. // SRV returns SRV records from etcd.
// If the Target is not a name but an IP address, a name is created on the fly. // If the Target is not a name but an IP address, a name is created on the fly.
func (e Etcd) SRV(zone string, state middleware.State, opt Options) (records, extra []dns.RR, debug []msg.Service, err error) { func (e Etcd) SRV(zone string, state request.Request, opt Options) (records, extra []dns.RR, debug []msg.Service, err error) {
services, debug, err := e.records(state, false, opt) services, debug, err := e.records(state, false, opt)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
@ -220,7 +222,7 @@ func (e Etcd) SRV(zone string, state middleware.State, opt Options) (records, ex
} }
// Internal name, we should have some info on them, either v4 or v6 // Internal name, we should have some info on them, either v4 or v6
// Clients expect a complete answer, because we are a recursor in their view. // Clients expect a complete answer, because we are a recursor in their view.
state1 := copyState(state, srv.Target, dns.TypeA) state1 := state.NewWithQuestion(srv.Target, dns.TypeA)
addr, debugAddr, e1 := e.A(zone, state1, nil, opt) addr, debugAddr, e1 := e.A(zone, state1, nil, opt)
if e1 == nil { if e1 == nil {
extra = append(extra, addr...) extra = append(extra, addr...)
@ -246,7 +248,7 @@ func (e Etcd) SRV(zone string, state middleware.State, opt Options) (records, ex
// MX returns MX records from etcd. // MX returns MX records from etcd.
// If the Target is not a name but an IP address, a name is created on the fly. // If the Target is not a name but an IP address, a name is created on the fly.
func (e Etcd) MX(zone string, state middleware.State, opt Options) (records, extra []dns.RR, debug []msg.Service, err error) { func (e Etcd) MX(zone string, state request.Request, opt Options) (records, extra []dns.RR, debug []msg.Service, err error) {
services, debug, err := e.records(state, false, opt) services, debug, err := e.records(state, false, opt)
if err != nil { if err != nil {
return nil, nil, debug, err return nil, nil, debug, err
@ -291,7 +293,7 @@ func (e Etcd) MX(zone string, state middleware.State, opt Options) (records, ext
break break
} }
// Internal name // Internal name
state1 := copyState(state, mx.Mx, dns.TypeA) state1 := state.NewWithQuestion(mx.Mx, dns.TypeA)
addr, debugAddr, e1 := e.A(zone, state1, nil, opt) addr, debugAddr, e1 := e.A(zone, state1, nil, opt)
if e1 == nil { if e1 == nil {
extra = append(extra, addr...) extra = append(extra, addr...)
@ -311,7 +313,7 @@ func (e Etcd) MX(zone string, state middleware.State, opt Options) (records, ext
return records, extra, debug, nil return records, extra, debug, nil
} }
func (e Etcd) CNAME(zone string, state middleware.State, opt Options) (records []dns.RR, debug []msg.Service, err error) { func (e Etcd) CNAME(zone string, state request.Request, opt Options) (records []dns.RR, debug []msg.Service, err error) {
services, debug, err := e.records(state, true, opt) services, debug, err := e.records(state, true, opt)
if err != nil { if err != nil {
return nil, debug, err return nil, debug, err
@ -327,7 +329,7 @@ func (e Etcd) CNAME(zone string, state middleware.State, opt Options) (records [
} }
// PTR returns the PTR records, only services that have a domain name as host are included. // PTR returns the PTR records, only services that have a domain name as host are included.
func (e Etcd) PTR(zone string, state middleware.State, opt Options) (records []dns.RR, debug []msg.Service, err error) { func (e Etcd) PTR(zone string, state request.Request, opt Options) (records []dns.RR, debug []msg.Service, err error) {
services, debug, err := e.records(state, true, opt) services, debug, err := e.records(state, true, opt)
if err != nil { if err != nil {
return nil, debug, err return nil, debug, err
@ -341,7 +343,7 @@ func (e Etcd) PTR(zone string, state middleware.State, opt Options) (records []d
return records, debug, nil return records, debug, nil
} }
func (e Etcd) TXT(zone string, state middleware.State, opt Options) (records []dns.RR, debug []msg.Service, err error) { func (e Etcd) TXT(zone string, state request.Request, opt Options) (records []dns.RR, debug []msg.Service, err error) {
services, debug, err := e.records(state, false, opt) services, debug, err := e.records(state, false, opt)
if err != nil { if err != nil {
return nil, debug, err return nil, debug, err
@ -356,7 +358,7 @@ func (e Etcd) TXT(zone string, state middleware.State, opt Options) (records []d
return records, debug, nil return records, debug, nil
} }
func (e Etcd) NS(zone string, state middleware.State, opt Options) (records, extra []dns.RR, debug []msg.Service, err error) { func (e Etcd) NS(zone string, state request.Request, opt Options) (records, extra []dns.RR, debug []msg.Service, err error) {
// NS record for this zone live in a special place, ns.dns.<zone>. Fake our lookup. // NS record for this zone live in a special place, ns.dns.<zone>. Fake our lookup.
// only a tad bit fishy... // only a tad bit fishy...
old := state.QName() old := state.QName()
@ -389,7 +391,7 @@ func (e Etcd) NS(zone string, state middleware.State, opt Options) (records, ext
} }
// SOA Record returns a SOA record. // SOA Record returns a SOA record.
func (e Etcd) SOA(zone string, state middleware.State, opt Options) ([]dns.RR, []msg.Service, error) { func (e Etcd) SOA(zone string, state request.Request, opt Options) ([]dns.RR, []msg.Service, error) {
header := dns.RR_Header{Name: zone, Rrtype: dns.TypeSOA, Ttl: 300, Class: dns.ClassINET} header := dns.RR_Header{Name: zone, Rrtype: dns.TypeSOA, Ttl: 300, Class: dns.ClassINET}
soa := &dns.SOA{Hdr: header, soa := &dns.SOA{Hdr: header,
@ -404,21 +406,3 @@ func (e Etcd) SOA(zone string, state middleware.State, opt Options) ([]dns.RR, [
// TODO(miek): fake some msg.Service here when returning. // TODO(miek): fake some msg.Service here when returning.
return []dns.RR{soa}, nil, nil return []dns.RR{soa}, nil, nil
} }
func isDuplicateCNAME(r *dns.CNAME, records []dns.RR) bool {
for _, rec := range records {
if v, ok := rec.(*dns.CNAME); ok {
if v.Target == r.Target {
return true
}
}
}
return false
}
// TODO(miek): Move to middleware?
func copyState(state middleware.State, target string, typ uint16) middleware.State {
state1 := middleware.State{W: state.W, Req: state.Req.Copy()}
state1.Req.Question[0] = dns.Question{Name: dns.Fqdn(target), Qclass: dns.ClassINET, Qtype: typ}
return state1
}

View file

@ -6,8 +6,8 @@ import (
"sort" "sort"
"testing" "testing"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/etcd/msg" "github.com/miekg/coredns/middleware/etcd/msg"
"github.com/miekg/coredns/middleware/pkg/dnsrecorder"
"github.com/miekg/coredns/middleware/test" "github.com/miekg/coredns/middleware/test"
"github.com/miekg/dns" "github.com/miekg/dns"
@ -25,14 +25,14 @@ func TestMultiLookup(t *testing.T) {
for _, tc := range dnsTestCasesMulti { for _, tc := range dnsTestCasesMulti {
m := tc.Msg() m := tc.Msg()
rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) rec := dnsrecorder.New(&test.ResponseWriter{})
_, err := etc.ServeDNS(ctxt, rec, m) _, err := etc.ServeDNS(ctxt, rec, m)
if err != nil { if err != nil {
t.Errorf("expected no error, got %v\n", err) t.Errorf("expected no error, got %v\n", err)
return return
} }
resp := rec.Msg()
resp := rec.Msg
sort.Sort(test.RRSet(resp.Answer)) sort.Sort(test.RRSet(resp.Answer))
sort.Sort(test.RRSet(resp.Ns)) sort.Sort(test.RRSet(resp.Ns))
sort.Sort(test.RRSet(resp.Extra)) sort.Sort(test.RRSet(resp.Extra))

View file

@ -10,8 +10,8 @@ import (
"strings" "strings"
"testing" "testing"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/etcd/msg" "github.com/miekg/coredns/middleware/etcd/msg"
"github.com/miekg/coredns/middleware/pkg/dnsrecorder"
"github.com/miekg/coredns/middleware/test" "github.com/miekg/coredns/middleware/test"
"github.com/miekg/dns" "github.com/miekg/dns"
@ -27,14 +27,14 @@ func TestOtherLookup(t *testing.T) {
for _, tc := range dnsTestCasesOther { for _, tc := range dnsTestCasesOther {
m := tc.Msg() m := tc.Msg()
rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) rec := dnsrecorder.New(&test.ResponseWriter{})
_, err := etc.ServeDNS(ctxt, rec, m) _, err := etc.ServeDNS(ctxt, rec, m)
if err != nil { if err != nil {
t.Errorf("expected no error, got %v\n", err) t.Errorf("expected no error, got %v\n", err)
continue continue
} }
resp := rec.Msg()
resp := rec.Msg
sort.Sort(test.RRSet(resp.Answer)) sort.Sort(test.RRSet(resp.Answer))
sort.Sort(test.RRSet(resp.Ns)) sort.Sort(test.RRSet(resp.Ns))
sort.Sort(test.RRSet(resp.Extra)) sort.Sort(test.RRSet(resp.Extra))

View file

@ -6,8 +6,8 @@ import (
"sort" "sort"
"testing" "testing"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/etcd/msg" "github.com/miekg/coredns/middleware/etcd/msg"
"github.com/miekg/coredns/middleware/pkg/dnsrecorder"
"github.com/miekg/coredns/middleware/proxy" "github.com/miekg/coredns/middleware/proxy"
"github.com/miekg/coredns/middleware/test" "github.com/miekg/coredns/middleware/test"
@ -27,14 +27,14 @@ func TestProxyLookupFailDebug(t *testing.T) {
for _, tc := range dnsTestCasesProxy { for _, tc := range dnsTestCasesProxy {
m := tc.Msg() m := tc.Msg()
rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) rec := dnsrecorder.New(&test.ResponseWriter{})
_, err := etc.ServeDNS(ctxt, rec, m) _, err := etc.ServeDNS(ctxt, rec, m)
if err != nil { if err != nil {
t.Errorf("expected no error, got %v\n", err) t.Errorf("expected no error, got %v\n", err)
continue continue
} }
resp := rec.Msg()
resp := rec.Msg
sort.Sort(test.RRSet(resp.Answer)) sort.Sort(test.RRSet(resp.Answer))
sort.Sort(test.RRSet(resp.Ns)) sort.Sort(test.RRSet(resp.Ns))
sort.Sort(test.RRSet(resp.Extra)) sort.Sort(test.RRSet(resp.Extra))

View file

@ -10,8 +10,8 @@ import (
"github.com/miekg/coredns/core/dnsserver" "github.com/miekg/coredns/core/dnsserver"
"github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/pkg/singleflight"
"github.com/miekg/coredns/middleware/proxy" "github.com/miekg/coredns/middleware/proxy"
"github.com/miekg/coredns/singleflight"
etcdc "github.com/coreos/etcd/client" etcdc "github.com/coreos/etcd/client"
"github.com/mholt/caddy" "github.com/mholt/caddy"
@ -70,7 +70,7 @@ func etcdParse(c *caddy.Controller) (*Etcd, bool, error) {
etc.Zones = make([]string, len(c.ServerBlockKeys)) etc.Zones = make([]string, len(c.ServerBlockKeys))
copy(etc.Zones, c.ServerBlockKeys) copy(etc.Zones, c.ServerBlockKeys)
} }
middleware.Zones(etc.Zones).FullyQualify() middleware.Zones(etc.Zones).Normalize()
if c.NextBlock() { if c.NextBlock() {
// TODO(miek): 2 switches? // TODO(miek): 2 switches?
switch c.Val() { switch c.Val() {

View file

@ -8,11 +8,11 @@ import (
"testing" "testing"
"time" "time"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/etcd/msg" "github.com/miekg/coredns/middleware/etcd/msg"
"github.com/miekg/coredns/middleware/pkg/dnsrecorder"
"github.com/miekg/coredns/middleware/pkg/singleflight"
"github.com/miekg/coredns/middleware/proxy" "github.com/miekg/coredns/middleware/proxy"
"github.com/miekg/coredns/middleware/test" "github.com/miekg/coredns/middleware/test"
"github.com/miekg/coredns/singleflight"
etcdc "github.com/coreos/etcd/client" etcdc "github.com/coreos/etcd/client"
"github.com/miekg/dns" "github.com/miekg/dns"
@ -65,14 +65,14 @@ func TestLookup(t *testing.T) {
for _, tc := range dnsTestCases { for _, tc := range dnsTestCases {
m := tc.Msg() m := tc.Msg()
rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) rec := dnsrecorder.New(&test.ResponseWriter{})
_, err := etc.ServeDNS(ctxt, rec, m) _, err := etc.ServeDNS(ctxt, rec, m)
if err != nil { if err != nil {
t.Errorf("expected no error, got: %v for %s %s\n", err, m.Question[0].Name, dns.Type(m.Question[0].Qtype)) t.Errorf("expected no error, got: %v for %s %s\n", err, m.Question[0].Name, dns.Type(m.Question[0].Qtype))
return return
} }
resp := rec.Msg()
resp := rec.Msg
sort.Sort(test.RRSet(resp.Answer)) sort.Sort(test.RRSet(resp.Answer))
sort.Sort(test.RRSet(resp.Ns)) sort.Sort(test.RRSet(resp.Ns))
sort.Sort(test.RRSet(resp.Extra)) sort.Sort(test.RRSet(resp.Extra))

View file

@ -4,7 +4,7 @@ import (
"errors" "errors"
"log" "log"
"github.com/miekg/coredns/middleware" "github.com/miekg/coredns/request"
"github.com/miekg/dns" "github.com/miekg/dns"
"golang.org/x/net/context" "golang.org/x/net/context"
@ -27,7 +27,7 @@ func (s Stub) ServeDNS(ctx context.Context, w dns.ResponseWriter, req *dns.Msg)
return dns.RcodeServerFailure, nil return dns.RcodeServerFailure, nil
} }
state := middleware.State{W: w, Req: req} state := request.Request{W: w, Req: req}
m, e := proxy.Forward(state) m, e := proxy.Forward(state)
if e != nil { if e != nil {
return dns.RcodeServerFailure, e return dns.RcodeServerFailure, e

View file

@ -8,8 +8,8 @@ import (
"strconv" "strconv"
"testing" "testing"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/etcd/msg" "github.com/miekg/coredns/middleware/etcd/msg"
"github.com/miekg/coredns/middleware/pkg/dnsrecorder"
"github.com/miekg/coredns/middleware/test" "github.com/miekg/coredns/middleware/test"
"github.com/miekg/dns" "github.com/miekg/dns"
@ -53,7 +53,7 @@ func TestStubLookup(t *testing.T) {
for _, tc := range dnsTestCasesStub { for _, tc := range dnsTestCasesStub {
m := tc.Msg() m := tc.Msg()
rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) rec := dnsrecorder.New(&test.ResponseWriter{})
_, err := etc.ServeDNS(ctxt, rec, m) _, err := etc.ServeDNS(ctxt, rec, m)
if err != nil && m.Question[0].Name == "example.org." { if err != nil && m.Question[0].Name == "example.org." {
// This is OK, we expect this backend to *not* work. // This is OK, we expect this backend to *not* work.
@ -62,12 +62,11 @@ func TestStubLookup(t *testing.T) {
if err != nil { if err != nil {
t.Errorf("expected no error, got %v for %s\n", err, m.Question[0].Name) t.Errorf("expected no error, got %v for %s\n", err, m.Question[0].Name)
} }
resp := rec.Msg() resp := rec.Msg
if resp == nil { if resp == nil {
// etcd not running? // etcd not running?
continue continue
} }
sort.Sort(test.RRSet(resp.Answer)) sort.Sort(test.RRSet(resp.Answer))
sort.Sort(test.RRSet(resp.Ns)) sort.Sort(test.RRSet(resp.Ns))
sort.Sort(test.RRSet(resp.Extra)) sort.Sort(test.RRSet(resp.Extra))

View file

@ -1,8 +0,0 @@
package middleware
import "github.com/miekg/dns"
func Exchange(c *dns.Client, m *dns.Msg, server string) (*dns.Msg, error) {
r, _, err := c.Exchange(m, server)
return r, err
}

View file

@ -5,7 +5,7 @@ import (
"strings" "strings"
"testing" "testing"
"github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware/pkg/dnsrecorder"
"github.com/miekg/coredns/middleware/test" "github.com/miekg/coredns/middleware/test"
"github.com/miekg/dns" "github.com/miekg/dns"
@ -57,14 +57,14 @@ func TestLookupDelegation(t *testing.T) {
for _, tc := range delegationTestCases { for _, tc := range delegationTestCases {
m := tc.Msg() m := tc.Msg()
rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) rec := dnsrecorder.New(&test.ResponseWriter{})
_, err := fm.ServeDNS(ctx, rec, m) _, err := fm.ServeDNS(ctx, rec, m)
if err != nil { if err != nil {
t.Errorf("expected no error, got %v\n", err) t.Errorf("expected no error, got %v\n", err)
return return
} }
resp := rec.Msg()
resp := rec.Msg
sort.Sort(test.RRSet(resp.Answer)) sort.Sort(test.RRSet(resp.Answer))
sort.Sort(test.RRSet(resp.Ns)) sort.Sort(test.RRSet(resp.Ns))
sort.Sort(test.RRSet(resp.Extra)) sort.Sort(test.RRSet(resp.Extra))

View file

@ -5,7 +5,7 @@ import (
"strings" "strings"
"testing" "testing"
"github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware/pkg/dnsrecorder"
"github.com/miekg/coredns/middleware/test" "github.com/miekg/coredns/middleware/test"
"github.com/miekg/dns" "github.com/miekg/dns"
@ -116,14 +116,14 @@ func TestLookupDNSSEC(t *testing.T) {
for _, tc := range dnssecTestCases { for _, tc := range dnssecTestCases {
m := tc.Msg() m := tc.Msg()
rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) rec := dnsrecorder.New(&test.ResponseWriter{})
_, err := fm.ServeDNS(ctx, rec, m) _, err := fm.ServeDNS(ctx, rec, m)
if err != nil { if err != nil {
t.Errorf("expected no error, got %v\n", err) t.Errorf("expected no error, got %v\n", err)
return return
} }
resp := rec.Msg()
resp := rec.Msg
sort.Sort(test.RRSet(resp.Answer)) sort.Sort(test.RRSet(resp.Answer))
sort.Sort(test.RRSet(resp.Ns)) sort.Sort(test.RRSet(resp.Ns))
sort.Sort(test.RRSet(resp.Extra)) sort.Sort(test.RRSet(resp.Extra))
@ -154,7 +154,7 @@ func BenchmarkLookupDNSSEC(b *testing.B) {
fm := File{Next: test.ErrorHandler(), Zones: Zones{Z: map[string]*Zone{testzone: zone}, Names: []string{testzone}}} fm := File{Next: test.ErrorHandler(), Zones: Zones{Z: map[string]*Zone{testzone: zone}, Names: []string{testzone}}}
ctx := context.TODO() ctx := context.TODO()
rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) rec := dnsrecorder.New(&test.ResponseWriter{})
tc := test.Case{ tc := test.Case{
Qname: "b.miek.nl.", Qtype: dns.TypeA, Do: true, Qname: "b.miek.nl.", Qtype: dns.TypeA, Do: true,

View file

@ -5,7 +5,7 @@ import (
"strings" "strings"
"testing" "testing"
"github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware/pkg/dnsrecorder"
"github.com/miekg/coredns/middleware/test" "github.com/miekg/coredns/middleware/test"
"github.com/miekg/dns" "github.com/miekg/dns"
@ -43,14 +43,14 @@ func TestLookupENT(t *testing.T) {
for _, tc := range entTestCases { for _, tc := range entTestCases {
m := tc.Msg() m := tc.Msg()
rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) rec := dnsrecorder.New(&test.ResponseWriter{})
_, err := fm.ServeDNS(ctx, rec, m) _, err := fm.ServeDNS(ctx, rec, m)
if err != nil { if err != nil {
t.Errorf("expected no error, got %v\n", err) t.Errorf("expected no error, got %v\n", err)
return return
} }
resp := rec.Msg()
resp := rec.Msg
sort.Sort(test.RRSet(resp.Answer)) sort.Sort(test.RRSet(resp.Answer))
sort.Sort(test.RRSet(resp.Ns)) sort.Sort(test.RRSet(resp.Ns))
sort.Sort(test.RRSet(resp.Extra)) sort.Sort(test.RRSet(resp.Extra))

View file

@ -6,6 +6,7 @@ import (
"log" "log"
"github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/request"
"github.com/miekg/dns" "github.com/miekg/dns"
"golang.org/x/net/context" "golang.org/x/net/context"
@ -24,7 +25,7 @@ type (
) )
func (f File) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { func (f File) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
state := middleware.State{W: w, Req: r} state := request.Request{W: w, Req: r}
if state.QClass() != dns.ClassINET { if state.QClass() != dns.ClassINET {
return dns.RcodeServerFailure, errors.New("can only deal with ClassINET") return dns.RcodeServerFailure, errors.New("can only deal with ClassINET")

View file

@ -5,7 +5,7 @@ import (
"strings" "strings"
"testing" "testing"
"github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware/pkg/dnsrecorder"
"github.com/miekg/coredns/middleware/test" "github.com/miekg/coredns/middleware/test"
"github.com/miekg/dns" "github.com/miekg/dns"
@ -87,14 +87,14 @@ func TestLookup(t *testing.T) {
for _, tc := range dnsTestCases { for _, tc := range dnsTestCases {
m := tc.Msg() m := tc.Msg()
rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) rec := dnsrecorder.New(&test.ResponseWriter{})
_, err := fm.ServeDNS(ctx, rec, m) _, err := fm.ServeDNS(ctx, rec, m)
if err != nil { if err != nil {
t.Errorf("expected no error, got %v\n", err) t.Errorf("expected no error, got %v\n", err)
return return
} }
resp := rec.Msg()
resp := rec.Msg
sort.Sort(test.RRSet(resp.Answer)) sort.Sort(test.RRSet(resp.Answer))
sort.Sort(test.RRSet(resp.Ns)) sort.Sort(test.RRSet(resp.Ns))
sort.Sort(test.RRSet(resp.Extra)) sort.Sort(test.RRSet(resp.Extra))
@ -122,7 +122,7 @@ func TestLookupNil(t *testing.T) {
ctx := context.TODO() ctx := context.TODO()
m := dnsTestCases[0].Msg() m := dnsTestCases[0].Msg()
rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) rec := dnsrecorder.New(&test.ResponseWriter{})
fm.ServeDNS(ctx, rec, m) fm.ServeDNS(ctx, rec, m)
} }
@ -134,7 +134,7 @@ func BenchmarkLookup(b *testing.B) {
fm := File{Next: test.ErrorHandler(), Zones: Zones{Z: map[string]*Zone{testzone: zone}, Names: []string{testzone}}} fm := File{Next: test.ErrorHandler(), Zones: Zones{Z: map[string]*Zone{testzone: zone}, Names: []string{testzone}}}
ctx := context.TODO() ctx := context.TODO()
rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) rec := dnsrecorder.New(&test.ResponseWriter{})
tc := test.Case{ tc := test.Case{
Qname: "www.miek.nl.", Qtype: dns.TypeA, Qname: "www.miek.nl.", Qtype: dns.TypeA,

View file

@ -5,6 +5,7 @@ import (
"log" "log"
"github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/request"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
@ -12,7 +13,7 @@ import (
// isNotify checks if state is a notify message and if so, will *also* check if it // isNotify checks if state is a notify message and if so, will *also* check if it
// is from one of the configured masters. If not it will not be a valid notify // is from one of the configured masters. If not it will not be a valid notify
// message. If the zone z is not a secondary zone the message will also be ignored. // message. If the zone z is not a secondary zone the message will also be ignored.
func (z *Zone) isNotify(state middleware.State) bool { func (z *Zone) isNotify(state request.Request) bool {
if state.Req.Opcode != dns.OpcodeNotify { if state.Req.Opcode != dns.OpcodeNotify {
return false return false
} }
@ -56,7 +57,7 @@ func notify(zone string, to []string) error {
func notifyAddr(c *dns.Client, m *dns.Msg, s string) error { func notifyAddr(c *dns.Client, m *dns.Msg, s string) error {
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
ret, err := middleware.Exchange(c, m, s) ret, _, err := c.Exchange(m, s)
if err != nil { if err != nil {
continue continue
} }

View file

@ -4,8 +4,6 @@ import (
"log" "log"
"time" "time"
"github.com/miekg/coredns/middleware"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
@ -75,7 +73,7 @@ func (z *Zone) shouldTransfer() (bool, error) {
Transfer: Transfer:
for _, tr := range z.TransferFrom { for _, tr := range z.TransferFrom {
Err = nil Err = nil
ret, err := middleware.Exchange(c, m, tr) ret, _, err := c.Exchange(m, tr)
if err != nil || ret.Rcode != dns.RcodeSuccess { if err != nil || ret.Rcode != dns.RcodeSuccess {
Err = err Err = err
continue continue

View file

@ -4,8 +4,8 @@ import (
"fmt" "fmt"
"testing" "testing"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/test" "github.com/miekg/coredns/middleware/test"
"github.com/miekg/coredns/request"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
@ -134,7 +134,7 @@ func TestIsNotify(t *testing.T) {
z := new(Zone) z := new(Zone)
z.Expired = new(bool) z.Expired = new(bool)
z.origin = testZone z.origin = testZone
state := NewState(testZone, dns.TypeSOA) state := newRequest(testZone, dns.TypeSOA)
// need to set opcode // need to set opcode
state.Req.Opcode = dns.OpcodeNotify state.Req.Opcode = dns.OpcodeNotify
@ -148,9 +148,9 @@ func TestIsNotify(t *testing.T) {
} }
} }
func NewState(zone string, qtype uint16) middleware.State { func newRequest(zone string, qtype uint16) request.Request {
m := new(dns.Msg) m := new(dns.Msg)
m.SetQuestion("example.com.", dns.TypeA) m.SetQuestion("example.com.", dns.TypeA)
m.SetEdns0(4097, true) m.SetEdns0(4097, true)
return middleware.State{W: &test.ResponseWriter{}, Req: m} return request.Request{W: &test.ResponseWriter{}, Req: m}
} }

View file

@ -1,9 +1,6 @@
package tree package tree
import ( import "github.com/miekg/dns"
"github.com/miekg/coredns/middleware"
"github.com/miekg/dns"
)
type Elem struct { type Elem struct {
m map[uint16][]dns.RR m map[uint16][]dns.RR
@ -91,8 +88,8 @@ func (e *Elem) Delete(rr dns.RR) (empty bool) {
return return
} }
// Less is a tree helper function that calls middleware.Less. // Less is a tree helper function that calls less.
func Less(a *Elem, name string) int { return middleware.Less(name, a.Name()) } func Less(a *Elem, name string) int { return less(name, a.Name()) }
// Assuming the same type and name this will check if the rdata is equal as well. // Assuming the same type and name this will check if the rdata is equal as well.
func equalRdata(a, b dns.RR) bool { func equalRdata(a, b dns.RR) bool {

View file

@ -1,4 +1,4 @@
package middleware package tree
import ( import (
"bytes" "bytes"
@ -6,7 +6,7 @@ import (
"github.com/miekg/dns" "github.com/miekg/dns"
) )
// Less returns <0 when a is less than b, 0 when they are equal and // less returns <0 when a is less than b, 0 when they are equal and
// >0 when a is larger than b. // >0 when a is larger than b.
// The function orders names in DNSSEC canonical order: RFC 4034s section-6.1 // The function orders names in DNSSEC canonical order: RFC 4034s section-6.1
// //
@ -14,7 +14,7 @@ import (
// for a blog article on this implementation. // for a blog article on this implementation.
// //
// The values of a and b are *not* lowercased before the comparison! // The values of a and b are *not* lowercased before the comparison!
func Less(a, b string) int { func less(a, b string) int {
i := 1 i := 1
aj := len(a) aj := len(a)
bj := len(b) bj := len(b)

View file

@ -1,4 +1,4 @@
package middleware package tree
import ( import (
"sort" "sort"
@ -10,7 +10,7 @@ type set []string
func (p set) Len() int { return len(p) } func (p set) Len() int { return len(p) }
func (p set) Swap(i, j int) { p[i], p[j] = p[j], p[i] } func (p set) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
func (p set) Less(i, j int) bool { d := Less(p[i], p[j]); return d <= 0 } func (p set) Less(i, j int) bool { d := less(p[i], p[j]); return d <= 0 }
func TestLess(t *testing.T) { func TestLess(t *testing.T) {
tests := []struct { tests := []struct {

View file

@ -5,7 +5,7 @@ import (
"strings" "strings"
"testing" "testing"
"github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware/pkg/dnsrecorder"
"github.com/miekg/coredns/middleware/test" "github.com/miekg/coredns/middleware/test"
"github.com/miekg/dns" "github.com/miekg/dns"
@ -57,14 +57,14 @@ func TestLookupWildcard(t *testing.T) {
for _, tc := range wildcardTestCases { for _, tc := range wildcardTestCases {
m := tc.Msg() m := tc.Msg()
rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) rec := dnsrecorder.New(&test.ResponseWriter{})
_, err := fm.ServeDNS(ctx, rec, m) _, err := fm.ServeDNS(ctx, rec, m)
if err != nil { if err != nil {
t.Errorf("expected no error, got %v\n", err) t.Errorf("expected no error, got %v\n", err)
return return
} }
resp := rec.Msg()
resp := rec.Msg
sort.Sort(test.RRSet(resp.Answer)) sort.Sort(test.RRSet(resp.Answer))
sort.Sort(test.RRSet(resp.Ns)) sort.Sort(test.RRSet(resp.Ns))
sort.Sort(test.RRSet(resp.Extra)) sort.Sort(test.RRSet(resp.Extra))

View file

@ -4,7 +4,7 @@ import (
"fmt" "fmt"
"log" "log"
"github.com/miekg/coredns/middleware" "github.com/miekg/coredns/request"
"github.com/miekg/dns" "github.com/miekg/dns"
"golang.org/x/net/context" "golang.org/x/net/context"
@ -18,7 +18,7 @@ type (
// Serve an AXFR (and fallback of IXFR) as well. // Serve an AXFR (and fallback of IXFR) as well.
func (x Xfr) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { func (x Xfr) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
state := middleware.State{W: w, Req: r} state := request.Request{W: w, Req: r}
if !x.TransferAllowed(state) { if !x.TransferAllowed(state) {
return dns.RcodeServerFailure, nil return dns.RcodeServerFailure, nil
} }

View file

@ -8,8 +8,8 @@ import (
"strings" "strings"
"sync" "sync"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/file/tree" "github.com/miekg/coredns/middleware/file/tree"
"github.com/miekg/coredns/request"
"github.com/fsnotify/fsnotify" "github.com/fsnotify/fsnotify"
"github.com/miekg/dns" "github.com/miekg/dns"
@ -102,7 +102,7 @@ func (z *Zone) Insert(r dns.RR) error {
func (z *Zone) Delete(r dns.RR) { z.Tree.Delete(r) } func (z *Zone) Delete(r dns.RR) { z.Tree.Delete(r) }
// TransferAllowed checks if incoming request for transferring the zone is allowed according to the ACLs. // TransferAllowed checks if incoming request for transferring the zone is allowed according to the ACLs.
func (z *Zone) TransferAllowed(state middleware.State) bool { func (z *Zone) TransferAllowed(req request.Request) bool {
for _, t := range z.TransferTo { for _, t := range z.TransferTo {
if t == "*" { if t == "*" {
return true return true

View file

@ -1,19 +0,0 @@
package middleware
import (
"os"
"strings"
"testing"
)
func TestfsPath(t *testing.T) {
if actual := fsPath(); !strings.HasSuffix(actual, ".coredns") {
t.Errorf("Expected path to be a .coredns folder, got: %v", actual)
}
os.Setenv("COREDNSPATH", "testpath")
defer os.Setenv("COREDNSPATH", "")
if actual, expected := fsPath(), "testpath"; actual != expected {
t.Errorf("Expected path to be %v, got: %v", expected, actual)
}
}

View file

@ -1,36 +0,0 @@
package middleware
import (
"net"
"strings"
"github.com/miekg/dns"
)
// Host represents a host from the Corefile, may contain port.
type (
Host string
Addr string
)
// Normalize will return the host portion of host, stripping
// of any port. The host will also be fully qualified and lowercased.
func (h Host) Normalize() string {
// separate host and port
host, _, err := net.SplitHostPort(string(h))
if err != nil {
host, _, _ = net.SplitHostPort(string(h) + ":")
}
return strings.ToLower(dns.Fqdn(host))
}
// Normalize will return a normalized address, if not port is specified
// port 53 is added, otherwise the port will be left as is.
func (a Addr) Normalize() string {
// separate host and port
addr, port, err := net.SplitHostPort(string(a))
if err != nil {
addr, port, _ = net.SplitHostPort(string(a) + ":53")
}
return net.JoinHostPort(addr, port)
}

View file

@ -2,16 +2,17 @@ package kubernetes
import ( import (
"fmt" "fmt"
"strings"
"github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/pkg/dnsutil"
"github.com/miekg/coredns/request"
"github.com/miekg/dns" "github.com/miekg/dns"
"golang.org/x/net/context" "golang.org/x/net/context"
) )
func (k Kubernetes) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { func (k Kubernetes) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
state := middleware.State{W: w, Req: r} state := request.Request{W: w, Req: r}
if state.QClass() != dns.ClassINET { if state.QClass() != dns.ClassINET {
return dns.RcodeServerFailure, fmt.Errorf("can only deal with ClassINET") return dns.RcodeServerFailure, fmt.Errorf("can only deal with ClassINET")
} }
@ -21,8 +22,8 @@ func (k Kubernetes) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.M
m.Authoritative, m.RecursionAvailable, m.Compress = true, true, true m.Authoritative, m.RecursionAvailable, m.Compress = true, true, true
// TODO: find an alternative to this block // TODO: find an alternative to this block
if strings.HasSuffix(state.Name(), arpaSuffix) { ip := dnsutil.ExtractAddressFromReverse(state.Name())
ip, _ := extractIP(state.Name()) if ip != "" {
records := k.getServiceRecordForIP(ip, state.Name()) records := k.getServiceRecordForIP(ip, state.Name())
if len(records) > 0 { if len(records) > 0 {
srvPTR := &records[0] srvPTR := &records[0]
@ -100,7 +101,7 @@ func (k Kubernetes) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.M
} }
// NoData write a nodata response to the client. // NoData write a nodata response to the client.
func (k Kubernetes) Err(zone string, rcode int, state middleware.State) (int, error) { func (k Kubernetes) Err(zone string, rcode int, state request.Request) (int, error) {
m := new(dns.Msg) m := new(dns.Msg)
m.SetRcode(state.Req, rcode) m.SetRcode(state.Req, rcode)
m.Ns = []dns.RR{k.SOA(zone, state)} m.Ns = []dns.RR{k.SOA(zone, state)}

View file

@ -4,13 +4,13 @@ package kubernetes
import ( import (
"errors" "errors"
"log" "log"
"strings"
"time" "time"
"github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/kubernetes/msg" "github.com/miekg/coredns/middleware/kubernetes/msg"
"github.com/miekg/coredns/middleware/kubernetes/nametemplate" "github.com/miekg/coredns/middleware/kubernetes/nametemplate"
"github.com/miekg/coredns/middleware/kubernetes/util" "github.com/miekg/coredns/middleware/kubernetes/util"
"github.com/miekg/coredns/middleware/pkg/dnsutil"
"github.com/miekg/coredns/middleware/proxy" "github.com/miekg/coredns/middleware/proxy"
"github.com/miekg/dns" "github.com/miekg/dns"
@ -100,8 +100,8 @@ func (k *Kubernetes) getZoneForName(name string) (string, []string) {
func (k *Kubernetes) Records(name string, exact bool) ([]msg.Service, error) { func (k *Kubernetes) Records(name string, exact bool) ([]msg.Service, error) {
// TODO: refector this. // TODO: refector this.
// Right now GetNamespaceFromSegmentArray do not supports PRE queries // Right now GetNamespaceFromSegmentArray do not supports PRE queries
if strings.HasSuffix(name, arpaSuffix) { ip := dnsutil.ExtractAddressFromReverse(name)
ip, _ := extractIP(name) if ip != "" {
records := k.getServiceRecordForIP(ip, name) records := k.getServiceRecordForIP(ip, name)
return records, nil return records, nil
} }

View file

@ -4,21 +4,17 @@ import (
"fmt" "fmt"
"math" "math"
"net" "net"
"strings"
"time" "time"
"github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/kubernetes/msg" "github.com/miekg/coredns/middleware/kubernetes/msg"
"github.com/miekg/coredns/middleware/pkg/dnsutil"
"github.com/miekg/coredns/request"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
const ( func (k Kubernetes) records(state request.Request, exact bool) ([]msg.Service, error) {
// arpaSuffix is the standard suffix for PTR IP reverse lookups.
arpaSuffix = ".in-addr.arpa."
)
func (k Kubernetes) records(state middleware.State, exact bool) ([]msg.Service, error) {
services, err := k.Records(state.Name(), exact) services, err := k.Records(state.Name(), exact)
if err != nil { if err != nil {
return nil, err return nil, err
@ -28,7 +24,7 @@ func (k Kubernetes) records(state middleware.State, exact bool) ([]msg.Service,
return services, nil return services, nil
} }
func (k Kubernetes) A(zone string, state middleware.State, previousRecords []dns.RR) (records []dns.RR, err error) { func (k Kubernetes) A(zone string, state request.Request, previousRecords []dns.RR) (records []dns.RR, err error) {
services, err := k.records(state, false) services, err := k.records(state, false)
if err != nil { if err != nil {
return nil, err return nil, err
@ -49,11 +45,11 @@ func (k Kubernetes) A(zone string, state middleware.State, previousRecords []dns
// don't add it, and just continue // don't add it, and just continue
continue continue
} }
if isDuplicateCNAME(newRecord, previousRecords) { if dnsutil.DuplicateCNAME(newRecord, previousRecords) {
continue continue
} }
state1 := copyState(state, serv.Host, state.QType()) state1 := state.NewWithQuestion(serv.Host, state.QType())
nextRecords, err := k.A(zone, state1, append(previousRecords, newRecord)) nextRecords, err := k.A(zone, state1, append(previousRecords, newRecord))
if err == nil { if err == nil {
@ -87,7 +83,7 @@ func (k Kubernetes) A(zone string, state middleware.State, previousRecords []dns
return records, nil return records, nil
} }
func (k Kubernetes) AAAA(zone string, state middleware.State, previousRecords []dns.RR) (records []dns.RR, err error) { func (k Kubernetes) AAAA(zone string, state request.Request, previousRecords []dns.RR) (records []dns.RR, err error) {
services, err := k.records(state, false) services, err := k.records(state, false)
if err != nil { if err != nil {
return nil, err return nil, err
@ -108,11 +104,11 @@ func (k Kubernetes) AAAA(zone string, state middleware.State, previousRecords []
// don't add it, and just continue // don't add it, and just continue
continue continue
} }
if isDuplicateCNAME(newRecord, previousRecords) { if dnsutil.DuplicateCNAME(newRecord, previousRecords) {
continue continue
} }
state1 := copyState(state, serv.Host, state.QType()) state1 := state.NewWithQuestion(serv.Host, state.QType())
nextRecords, err := k.AAAA(zone, state1, append(previousRecords, newRecord)) nextRecords, err := k.AAAA(zone, state1, append(previousRecords, newRecord))
if err == nil { if err == nil {
@ -149,7 +145,7 @@ func (k Kubernetes) AAAA(zone string, state middleware.State, previousRecords []
// SRV returns SRV records from kubernetes. // SRV returns SRV records from kubernetes.
// If the Target is not a name but an IP address, a name is created on the fly. // If the Target is not a name but an IP address, a name is created on the fly.
func (k Kubernetes) SRV(zone string, state middleware.State) (records []dns.RR, extra []dns.RR, err error) { func (k Kubernetes) SRV(zone string, state request.Request) (records []dns.RR, extra []dns.RR, err error) {
services, err := k.records(state, false) services, err := k.records(state, false)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
@ -207,7 +203,7 @@ func (k Kubernetes) SRV(zone string, state middleware.State) (records []dns.RR,
} }
// Internal name, we should have some info on them, either v4 or v6 // Internal name, we should have some info on them, either v4 or v6
// Clients expect a complete answer, because we are a recursor in their view. // Clients expect a complete answer, because we are a recursor in their view.
state1 := copyState(state, srv.Target, dns.TypeA) state1 := state.NewWithQuestion(srv.Target, dns.TypeA)
addr, e1 := k.A(zone, state1, nil) addr, e1 := k.A(zone, state1, nil)
if e1 == nil { if e1 == nil {
extra = append(extra, addr...) extra = append(extra, addr...)
@ -231,21 +227,21 @@ func (k Kubernetes) SRV(zone string, state middleware.State) (records []dns.RR,
} }
// Returning MX records from kubernetes not implemented. // Returning MX records from kubernetes not implemented.
func (k Kubernetes) MX(zone string, state middleware.State) (records []dns.RR, extra []dns.RR, err error) { func (k Kubernetes) MX(zone string, state request.Request) (records []dns.RR, extra []dns.RR, err error) {
return nil, nil, err return nil, nil, err
} }
// Returning CNAME records from kubernetes not implemented. // Returning CNAME records from kubernetes not implemented.
func (k Kubernetes) CNAME(zone string, state middleware.State) (records []dns.RR, err error) { func (k Kubernetes) CNAME(zone string, state request.Request) (records []dns.RR, err error) {
return nil, err return nil, err
} }
// Returning TXT records from kubernetes not implemented. // Returning TXT records from kubernetes not implemented.
func (k Kubernetes) TXT(zone string, state middleware.State) (records []dns.RR, err error) { func (k Kubernetes) TXT(zone string, state request.Request) (records []dns.RR, err error) {
return nil, err return nil, err
} }
func (k Kubernetes) NS(zone string, state middleware.State) (records, extra []dns.RR, err error) { func (k Kubernetes) NS(zone string, state request.Request) (records, extra []dns.RR, err error) {
// NS record for this zone live in a special place, ns.dns.<zone>. Fake our lookup. // NS record for this zone live in a special place, ns.dns.<zone>. Fake our lookup.
// only a tad bit fishy... // only a tad bit fishy...
old := state.QName() old := state.QName()
@ -278,7 +274,7 @@ func (k Kubernetes) NS(zone string, state middleware.State) (records, extra []dn
} }
// SOA Record returns a SOA record. // SOA Record returns a SOA record.
func (k Kubernetes) SOA(zone string, state middleware.State) *dns.SOA { func (k Kubernetes) SOA(zone string, state request.Request) *dns.SOA {
header := dns.RR_Header{Name: zone, Rrtype: dns.TypeSOA, Ttl: 300, Class: dns.ClassINET} header := dns.RR_Header{Name: zone, Rrtype: dns.TypeSOA, Ttl: 300, Class: dns.ClassINET}
return &dns.SOA{Hdr: header, return &dns.SOA{Hdr: header,
Mbox: "hostmaster." + zone, Mbox: "hostmaster." + zone,
@ -291,9 +287,9 @@ func (k Kubernetes) SOA(zone string, state middleware.State) *dns.SOA {
} }
} }
func (k Kubernetes) PTR(zone string, state middleware.State) ([]dns.RR, error) { func (k Kubernetes) PTR(zone string, state request.Request) ([]dns.RR, error) {
reverseIP, ok := extractIP(state.Name()) reverseIP := dnsutil.ExtractAddressFromReverse(state.Name())
if !ok { if reverseIP == "" {
return nil, fmt.Errorf("does not support reverse lookup for %s", state.QName()) return nil, fmt.Errorf("does not support reverse lookup for %s", state.QName())
} }
@ -318,41 +314,3 @@ func (k Kubernetes) PTR(zone string, state middleware.State) ([]dns.RR, error) {
} }
return records, nil return records, nil
} }
func isDuplicateCNAME(r *dns.CNAME, records []dns.RR) bool {
for _, rec := range records {
if v, ok := rec.(*dns.CNAME); ok {
if v.Target == r.Target {
return true
}
}
}
return false
}
func copyState(state middleware.State, target string, typ uint16) middleware.State {
state1 := middleware.State{W: state.W, Req: state.Req.Copy()}
state1.Req.Question[0] = dns.Question{Name: dns.Fqdn(target), Qtype: dns.ClassINET, Qclass: typ}
return state1
}
// extractIP turns a standard PTR reverse record lookup name
// into an IP address
func extractIP(reverseName string) (string, bool) {
if !strings.HasSuffix(reverseName, arpaSuffix) {
return "", false
}
search := strings.TrimSuffix(reverseName, arpaSuffix)
// reverse the segments and then combine them
segments := reverseArray(strings.Split(search, "."))
return strings.Join(segments, "."), true
}
func reverseArray(arr []string) []string {
for i := 0; i < len(arr)/2; i++ {
j := len(arr) - i - 1
arr[i], arr[j] = arr[j], arr[i]
}
return arr
}

View file

@ -68,7 +68,7 @@ func kubernetesParse(c *caddy.Controller) (Kubernetes, error) {
} }
k8s.Zones = NormalizeZoneList(zones) k8s.Zones = NormalizeZoneList(zones)
middleware.Zones(k8s.Zones).FullyQualify() middleware.Zones(k8s.Zones).Normalize()
if k8s.Zones == nil || len(k8s.Zones) < 1 { if k8s.Zones == nil || len(k8s.Zones) < 1 {
err = errors.New("Zone name must be provided for kubernetes middleware.") err = errors.New("Zone name must be provided for kubernetes middleware.")

View file

@ -4,6 +4,7 @@ import (
"testing" "testing"
"github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/pkg/dnsrecorder"
"github.com/miekg/coredns/middleware/test" "github.com/miekg/coredns/middleware/test"
"github.com/miekg/dns" "github.com/miekg/dns"
@ -56,7 +57,7 @@ func TestLoadBalance(t *testing.T) {
}, },
} }
rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) rec := dnsrecorder.New(&test.ResponseWriter{})
for i, test := range tests { for i, test := range tests {
req := new(dns.Msg) req := new(dns.Msg)
@ -71,7 +72,7 @@ func TestLoadBalance(t *testing.T) {
} }
cname := 0 cname := 0
for _, r := range rec.Msg().Answer { for _, r := range rec.Msg.Answer {
if r.Header().Rrtype != dns.TypeCNAME { if r.Header().Rrtype != dns.TypeCNAME {
break break
} }
@ -81,7 +82,7 @@ func TestLoadBalance(t *testing.T) {
t.Errorf("Test %d: Expected %d cnames in Answer, but got %d", i, test.cnameAnswer, cname) t.Errorf("Test %d: Expected %d cnames in Answer, but got %d", i, test.cnameAnswer, cname)
} }
cname = 0 cname = 0
for _, r := range rec.Msg().Extra { for _, r := range rec.Msg.Extra {
if r.Header().Rrtype != dns.TypeCNAME { if r.Header().Rrtype != dns.TypeCNAME {
break break
} }

View file

@ -7,6 +7,11 @@ import (
"github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/metrics" "github.com/miekg/coredns/middleware/metrics"
"github.com/miekg/coredns/middleware/pkg/dnsrecorder"
"github.com/miekg/coredns/middleware/pkg/rcode"
"github.com/miekg/coredns/middleware/pkg/replacer"
"github.com/miekg/coredns/middleware/pkg/roller"
"github.com/miekg/coredns/request"
"github.com/miekg/dns" "github.com/miekg/dns"
"golang.org/x/net/context" "golang.org/x/net/context"
@ -20,32 +25,30 @@ type Logger struct {
} }
func (l Logger) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { func (l Logger) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
state := middleware.State{W: w, Req: r} state := request.Request{W: w, Req: r}
for _, rule := range l.Rules { for _, rule := range l.Rules {
if middleware.Name(rule.NameScope).Matches(state.Name()) { if middleware.Name(rule.NameScope).Matches(state.Name()) {
responseRecorder := middleware.NewResponseRecorder(w) responseRecorder := dnsrecorder.New(w)
rcode, err := l.Next.ServeDNS(ctx, responseRecorder, r) rc, err := l.Next.ServeDNS(ctx, responseRecorder, r)
if rcode > 0 { if rc > 0 {
// There was an error up the chain, but no response has been written yet. // There was an error up the chain, but no response has been written yet.
// The error must be handled here so the log entry will record the response size. // The error must be handled here so the log entry will record the response size.
if l.ErrorFunc != nil { if l.ErrorFunc != nil {
l.ErrorFunc(responseRecorder, r, rcode) l.ErrorFunc(responseRecorder, r, rc)
} else { } else {
rc := middleware.RcodeToString(rcode)
answer := new(dns.Msg) answer := new(dns.Msg)
answer.SetRcode(r, rcode) answer.SetRcode(r, rc)
state.SizeAndDo(answer) state.SizeAndDo(answer)
metrics.Report(state, metrics.Dropped, rc, answer.Len(), time.Now()) metrics.Report(state, metrics.Dropped, rcode.ToString(rc), answer.Len(), time.Now())
w.WriteMsg(answer) w.WriteMsg(answer)
} }
rcode = 0 rc = 0
} }
rep := middleware.NewReplacer(r, responseRecorder, CommonLogEmptyValue) rep := replacer.New(r, responseRecorder, CommonLogEmptyValue)
rule.Log.Println(rep.Replace(rule.Format)) rule.Log.Println(rep.Replace(rule.Format))
return rcode, err return rc, err
} }
} }
@ -58,7 +61,7 @@ type Rule struct {
OutputFile string OutputFile string
Format string Format string
Log *log.Logger Log *log.Logger
Roller *middleware.LogRoller Roller *roller.LogRoller
} }
const ( const (

View file

@ -6,7 +6,7 @@ import (
"strings" "strings"
"testing" "testing"
"github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware/pkg/dnsrecorder"
"github.com/miekg/coredns/middleware/test" "github.com/miekg/coredns/middleware/test"
"github.com/miekg/dns" "github.com/miekg/dns"
@ -37,7 +37,7 @@ func TestLoggedStatus(t *testing.T) {
r := new(dns.Msg) r := new(dns.Msg)
r.SetQuestion("example.org.", dns.TypeA) r.SetQuestion("example.org.", dns.TypeA)
rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) rec := dnsrecorder.New(&test.ResponseWriter{})
rcode, _ := logger.ServeDNS(ctx, rec, r) rcode, _ := logger.ServeDNS(ctx, rec, r)
if rcode != 0 { if rcode != 0 {

View file

@ -6,7 +6,7 @@ import (
"os" "os"
"github.com/miekg/coredns/core/dnsserver" "github.com/miekg/coredns/core/dnsserver"
"github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware/pkg/roller"
"github.com/hashicorp/go-syslog" "github.com/hashicorp/go-syslog"
"github.com/mholt/caddy" "github.com/mholt/caddy"
@ -75,13 +75,13 @@ func logParse(c *caddy.Controller) ([]Rule, error) {
for c.Next() { for c.Next() {
args := c.RemainingArgs() args := c.RemainingArgs()
var logRoller *middleware.LogRoller var logRoller *roller.LogRoller
if c.NextBlock() { if c.NextBlock() {
if c.Val() == "rotate" { if c.Val() == "rotate" {
if c.NextArg() { if c.NextArg() {
if c.Val() == "{" { if c.Val() == "{" {
var err error var err error
logRoller, err = middleware.ParseRoller(c) logRoller, err = roller.Parse(c)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -3,7 +3,7 @@ package log
import ( import (
"testing" "testing"
"github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware/pkg/roller"
"github.com/mholt/caddy" "github.com/mholt/caddy"
) )
@ -68,7 +68,7 @@ func TestLogParse(t *testing.T) {
NameScope: ".", NameScope: ".",
OutputFile: "access.log", OutputFile: "access.log",
Format: DefaultLogFormat, Format: DefaultLogFormat,
Roller: &middleware.LogRoller{ Roller: &roller.LogRoller{
MaxSize: 2, MaxSize: 2,
MaxAge: 10, MaxAge: 10,
MaxBackups: 3, MaxBackups: 3,

View file

@ -4,13 +4,16 @@ import (
"time" "time"
"github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/pkg/dnsrecorder"
"github.com/miekg/coredns/middleware/pkg/rcode"
"github.com/miekg/coredns/request"
"github.com/miekg/dns" "github.com/miekg/dns"
"golang.org/x/net/context" "golang.org/x/net/context"
) )
func (m Metrics) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { func (m Metrics) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
state := middleware.State{W: w, Req: r} state := request.Request{W: w, Req: r}
qname := state.QName() qname := state.QName()
zone := middleware.Zones(m.ZoneNames).Matches(qname) zone := middleware.Zones(m.ZoneNames).Matches(qname)
@ -19,35 +22,35 @@ func (m Metrics) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg)
} }
// Record response to get status code and size of the reply. // Record response to get status code and size of the reply.
rw := middleware.NewResponseRecorder(w) rw := dnsrecorder.New(w)
status, err := m.Next.ServeDNS(ctx, rw, r) status, err := m.Next.ServeDNS(ctx, rw, r)
Report(state, zone, rw.Rcode(), rw.Size(), rw.Start()) Report(state, zone, rcode.ToString(rw.Rcode), rw.Size, rw.Start)
return status, err return status, err
} }
// Report is a plain reporting function that the server can use for REFUSED and other // Report is a plain reporting function that the server can use for REFUSED and other
// queries that are turned down because they don't match any middleware. // queries that are turned down because they don't match any middleware.
func Report(state middleware.State, zone, rcode string, size int, start time.Time) { func Report(req request.Request, zone, rcode string, size int, start time.Time) {
if requestCount == nil { if requestCount == nil {
// no metrics are enabled // no metrics are enabled
return return
} }
// Proto and Family // Proto and Family
net := state.Proto() net := req.Proto()
fam := "1" fam := "1"
if state.Family() == 2 { if req.Family() == 2 {
fam = "2" fam = "2"
} }
typ := state.QType() typ := req.QType()
requestCount.WithLabelValues(zone, net, fam).Inc() requestCount.WithLabelValues(zone, net, fam).Inc()
requestDuration.WithLabelValues(zone).Observe(float64(time.Since(start) / time.Millisecond)) requestDuration.WithLabelValues(zone).Observe(float64(time.Since(start) / time.Millisecond))
if state.Do() { if req.Do() {
requestDo.WithLabelValues(zone).Inc() requestDo.WithLabelValues(zone).Inc()
} }
@ -59,10 +62,10 @@ func Report(state middleware.State, zone, rcode string, size int, start time.Tim
if typ == dns.TypeIXFR || typ == dns.TypeAXFR { if typ == dns.TypeIXFR || typ == dns.TypeAXFR {
responseTransferSize.WithLabelValues(zone, net).Observe(float64(size)) responseTransferSize.WithLabelValues(zone, net).Observe(float64(size))
requestTransferSize.WithLabelValues(zone, net).Observe(float64(state.Size())) requestTransferSize.WithLabelValues(zone, net).Observe(float64(req.Size()))
} else { } else {
responseSize.WithLabelValues(zone, net).Observe(float64(size)) responseSize.WithLabelValues(zone, net).Observe(float64(size))
requestSize.WithLabelValues(zone, net).Observe(float64(state.Size())) requestSize.WithLabelValues(zone, net).Observe(float64(req.Size()))
} }
responseRcode.WithLabelValues(zone, rcode).Inc() responseRcode.WithLabelValues(zone, rcode).Inc()

View file

@ -1,25 +0,0 @@
package middleware
import (
"strings"
"github.com/miekg/dns"
)
// Name represents a domain name.
type Name string
// Matches checks to see if other is a subdomain (or the same domain) of n.
// This method assures that names can be easily and consistently matched.
func (n Name) Matches(child string) bool {
if dns.Name(n) == dns.Name(child) {
return true
}
return dns.IsSubDomain(string(n), child)
}
// Normalize lowercases and makes n fully qualified.
func (n Name) Normalize() string {
return strings.ToLower(dns.Fqdn(string(n)))
}

78
middleware/normalize.go Normal file
View file

@ -0,0 +1,78 @@
package middleware
import (
"net"
"strings"
"github.com/miekg/dns"
)
type Zones []string
// Matches checks is qname is a subdomain of any of the zones in z. The match
// will return the most specific zones that matches other. The empty string
// signals a not found condition.
func (z Zones) Matches(qname string) string {
zone := ""
for _, zname := range z {
if dns.IsSubDomain(zname, qname) {
// TODO(miek): hmm, add test for this case
if len(zname) > len(zone) {
zone = zname
}
}
}
return zone
}
// Normalize fully qualifies all zones in z.
func (z Zones) Normalize() {
for i, _ := range z {
z[i] = Name(z[i]).Normalize()
}
}
// Name represents a domain name.
type Name string
// Matches checks to see if other is a subdomain (or the same domain) of n.
// This method assures that names can be easily and consistently matched.
func (n Name) Matches(child string) bool {
if dns.Name(n) == dns.Name(child) {
return true
}
return dns.IsSubDomain(string(n), child)
}
// Normalize lowercases and makes n fully qualified.
func (n Name) Normalize() string { return strings.ToLower(dns.Fqdn(string(n))) }
// Host represents a host from the Corefile, may contain port.
type (
Host string
Addr string
)
// Normalize will return the host portion of host, stripping
// of any port. The host will also be fully qualified and lowercased.
func (h Host) Normalize() string {
// separate host and port
host, _, err := net.SplitHostPort(string(h))
if err != nil {
host, _, _ = net.SplitHostPort(string(h) + ":")
}
return Name(host).Normalize()
}
// Normalize will return a normalized address, if not port is specified
// port 53 is added, otherwise the port will be left as is.
func (a Addr) Normalize() string {
// separate host and port
addr, port, err := net.SplitHostPort(string(a))
if err != nil {
addr, port, _ = net.SplitHostPort(string(a) + ":53")
}
// TODO(miek): lowercase it?
return net.JoinHostPort(addr, port)
}

View file

@ -0,0 +1,57 @@
package dnsrecorder
import (
"time"
"github.com/miekg/dns"
)
// Recorder is a type of ResponseWriter that captures
// the rcode code written to it and also the size of the message
// written in the response. A rcode code does not have
// to be written, however, in which case 0 must be assumed.
// It is best to have the constructor initialize this type
// with that default status code.
type Recorder struct {
dns.ResponseWriter
Rcode int
Size int
Msg *dns.Msg
Start time.Time
}
// New makes and returns a new Recorder,
// which captures the DNS rcode from the ResponseWriter
// and also the length of the response message written through it.
func New(w dns.ResponseWriter) *Recorder {
return &Recorder{
ResponseWriter: w,
Rcode: 0,
Msg: nil,
Start: time.Now(),
}
}
// WriteMsg records the status code and calls the
// underlying ResponseWriter's WriteMsg method.
func (r *Recorder) WriteMsg(res *dns.Msg) error {
r.Rcode = res.Rcode
// We may get called multiple times (axfr for instance).
// Save the last message, but add the sizes.
r.Size += res.Len()
r.Msg = res
return r.ResponseWriter.WriteMsg(res)
}
// Write is a wrapper that records the size of the message that gets written.
func (r *Recorder) Write(buf []byte) (int, error) {
n, err := r.ResponseWriter.Write(buf)
if err == nil {
r.Size += n
}
return n, err
}
// Hijack implements dns.Hijacker. It simply wraps the underlying
// ResponseWriter's Hijack method if there is one, or returns an error.
func (r *Recorder) Hijack() { r.ResponseWriter.Hijack(); return }

View file

@ -1,4 +1,4 @@
package middleware package dnsrecorder
/* /*
func TestNewResponseRecorder(t *testing.T) { func TestNewResponseRecorder(t *testing.T) {

View file

@ -0,0 +1,15 @@
package dnsutil
import "github.com/miekg/dns"
// DuplicateCNAME returns true if r already exists in records.
func DuplicateCNAME(r *dns.CNAME, records []dns.RR) bool {
for _, rec := range records {
if v, ok := rec.(*dns.CNAME); ok {
if v.Target == r.Target {
return true
}
}
}
return false
}

View file

@ -0,0 +1,40 @@
package dnsutil
import "strings"
// ExtractAddressFromReverse turns a standard PTR reverse record name
// into an IP address. This works for ipv4 or ipv6.
//
// 54.119.58.176.in-addr.arpa. becomes 176.58.119.54. If the conversion
// failes the empty string is returned.
func ExtractAddressFromReverse(reverseName string) string {
search := ""
switch {
case strings.HasSuffix(reverseName, v4arpaSuffix):
search = strings.TrimSuffix(reverseName, v4arpaSuffix)
case strings.HasSuffix(reverseName, v6arpaSuffix):
search = strings.TrimSuffix(reverseName, v6arpaSuffix)
default:
return ""
}
// Reverse the segments and then combine them.
segments := reverse(strings.Split(search, "."))
return strings.Join(segments, ".")
}
func reverse(slice []string) []string {
for i := 0; i < len(slice)/2; i++ {
j := len(slice) - i - 1
slice[i], slice[j] = slice[j], slice[i]
}
return slice
}
const (
// v4arpaSuffix is the reverse tree suffix for v4 IP addresses.
v4arpaSuffix = ".in-addr.arpa."
// v6arpaSuffix is the reverse tree suffix for v6 IP addresses.
v6arpaSuffix = ".ip6.arpa."
)

View file

@ -1,4 +1,4 @@
package middleware package edns
import ( import (
"errors" "errors"
@ -6,11 +6,11 @@ import (
"github.com/miekg/dns" "github.com/miekg/dns"
) )
// Edns0Version checks the EDNS version in the request. If error // Version checks the EDNS version in the request. If error
// is nil everything is OK and we can invoke the middleware. If non-nil, the // is nil everything is OK and we can invoke the middleware. If non-nil, the
// returned Msg is valid to be returned to the client (and should). For some // returned Msg is valid to be returned to the client (and should). For some
// reason this response should not contain a question RR in the question section. // reason this response should not contain a question RR in the question section.
func Edns0Version(req *dns.Msg) (*dns.Msg, error) { func Version(req *dns.Msg) (*dns.Msg, error) {
opt := req.IsEdns0() opt := req.IsEdns0()
if opt == nil { if opt == nil {
return nil, nil return nil, nil
@ -33,8 +33,8 @@ func Edns0Version(req *dns.Msg) (*dns.Msg, error) {
return m, errors.New("EDNS0 BADVERS") return m, errors.New("EDNS0 BADVERS")
} }
// edns0Size returns a normalized size based on proto. // Size returns a normalized size based on proto.
func edns0Size(proto string, size int) int { func Size(proto string, size int) int {
if proto == "tcp" { if proto == "tcp" {
return dns.MaxMsgSize return dns.MaxMsgSize
} }

View file

@ -1,4 +1,4 @@
package middleware package edns
import ( import (
"testing" "testing"
@ -6,21 +6,21 @@ import (
"github.com/miekg/dns" "github.com/miekg/dns"
) )
func TestEdns0Version(t *testing.T) { func TestVersion(t *testing.T) {
m := ednsMsg() m := ednsMsg()
m.Extra[0].(*dns.OPT).SetVersion(2) m.Extra[0].(*dns.OPT).SetVersion(2)
_, err := Edns0Version(m) _, err := Version(m)
if err == nil { if err == nil {
t.Errorf("expected wrong version, but got OK") t.Errorf("expected wrong version, but got OK")
} }
} }
func TestEdns0VersionNoEdns(t *testing.T) { func TestVersionNoEdns(t *testing.T) {
m := ednsMsg() m := ednsMsg()
m.Extra = nil m.Extra = nil
_, err := Edns0Version(m) _, err := Version(m)
if err != nil { if err != nil {
t.Errorf("expected no error, but got one: %s", err) t.Errorf("expected no error, but got one: %s", err)
} }

View file

@ -1,4 +1,4 @@
package middleware package rcode
import ( import (
"strconv" "strconv"
@ -6,7 +6,7 @@ import (
"github.com/miekg/dns" "github.com/miekg/dns"
) )
func RcodeToString(rcode int) string { func ToString(rcode int) string {
if str, ok := dns.RcodeToString[rcode]; ok { if str, ok := dns.RcodeToString[rcode]; ok {
return str return str
} }

View file

@ -1,10 +1,13 @@
package middleware package replacer
import ( import (
"strconv" "strconv"
"strings" "strings"
"time" "time"
"github.com/miekg/coredns/middleware/pkg/dnsrecorder"
"github.com/miekg/coredns/request"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
@ -22,46 +25,43 @@ type replacer struct {
emptyValue string emptyValue string
} }
// NewReplacer makes a new replacer based on r and rr. // New makes a new replacer based on r and rr.
// Do not create a new replacer until r and rr have all // Do not create a new replacer until r and rr have all
// the needed values, because this function copies those // the needed values, because this function copies those
// values into the replacer. rr may be nil if it is not // values into the replacer. rr may be nil if it is not
// available. emptyValue should be the string that is used // available. emptyValue should be the string that is used
// in place of empty string (can still be empty string). // in place of empty string (can still be empty string).
func NewReplacer(r *dns.Msg, rr *ResponseRecorder, emptyValue string) Replacer { func New(r *dns.Msg, rr *dnsrecorder.Recorder, emptyValue string) Replacer {
state := State{W: rr, Req: r} req := request.Request{W: rr, Req: r}
rep := replacer{ rep := replacer{
replacements: map[string]string{ replacements: map[string]string{
"{type}": state.Type(), "{type}": req.Type(),
"{name}": state.Name(), "{name}": req.Name(),
"{class}": state.Class(), "{class}": req.Class(),
"{proto}": state.Proto(), "{proto}": req.Proto(),
"{when}": func() string { "{when}": func() string {
return time.Now().Format(timeFormat) return time.Now().Format(timeFormat)
}(), }(),
"{remote}": state.IP(), "{remote}": req.IP(),
"{port}": func() string { "{port}": req.Port(),
p, _ := state.Port()
return p
}(),
}, },
emptyValue: emptyValue, emptyValue: emptyValue,
} }
if rr != nil { if rr != nil {
rcode := dns.RcodeToString[rr.rcode] rcode := dns.RcodeToString[rr.Rcode]
if rcode == "" { if rcode == "" {
rcode = strconv.Itoa(rr.rcode) rcode = strconv.Itoa(rr.Rcode)
} }
rep.replacements["{rcode}"] = rcode rep.replacements["{rcode}"] = rcode
rep.replacements["{size}"] = strconv.Itoa(rr.size) rep.replacements["{size}"] = strconv.Itoa(rr.Size)
rep.replacements["{duration}"] = time.Since(rr.start).String() rep.replacements["{duration}"] = time.Since(rr.Start).String()
} }
// Header placeholders (case-insensitive) // Header placeholders (case-insensitive)
rep.replacements[headerReplacer+"id}"] = strconv.Itoa(int(r.Id)) rep.replacements[headerReplacer+"id}"] = strconv.Itoa(int(r.Id))
rep.replacements[headerReplacer+"opcode}"] = strconv.Itoa(int(r.Opcode)) rep.replacements[headerReplacer+"opcode}"] = strconv.Itoa(int(r.Opcode))
rep.replacements[headerReplacer+"do}"] = boolToString(state.Do()) rep.replacements[headerReplacer+"do}"] = boolToString(req.Do())
rep.replacements[headerReplacer+"bufsize}"] = strconv.Itoa(state.Size()) rep.replacements[headerReplacer+"bufsize}"] = strconv.Itoa(req.Size())
return rep return rep
} }

View file

@ -1,4 +1,4 @@
package middleware package replacer
/* /*
func TestNewReplacer(t *testing.T) { func TestNewReplacer(t *testing.T) {

View file

@ -1,19 +1,19 @@
package middleware package response
import "github.com/miekg/dns" import "github.com/miekg/dns"
type MsgType int type Type int
const ( const (
Success MsgType = iota Success Type = iota
NameError // NXDOMAIN in header, SOA in auth. NameError // NXDOMAIN in header, SOA in auth.
NoData // NOERROR in header, SOA in auth. NoData // NOERROR in header, SOA in auth.
Delegation // NOERROR in header, NS in auth, optionally fluff in additional (not checked). Delegation // NOERROR in header, NS in auth, optionally fluff in additional (not checked).
OtherError // Don't cache these. OtherError // Don't cache these.
) )
// Classify classifies a message, it returns the MessageType. // Classify classifies a message, it returns the Type.
func Classify(m *dns.Msg) (MsgType, *dns.OPT) { func Classify(m *dns.Msg) (Type, *dns.OPT) {
opt := m.IsEdns0() opt := m.IsEdns0()
if len(m.Answer) > 0 && m.Rcode == dns.RcodeSuccess { if len(m.Answer) > 0 && m.Rcode == dns.RcodeSuccess {

View file

@ -1,4 +1,4 @@
package middleware package response
import ( import (
"testing" "testing"

View file

@ -1,4 +1,4 @@
package middleware package roller
import ( import (
"io" "io"
@ -8,7 +8,7 @@ import (
"gopkg.in/natefinch/lumberjack.v2" "gopkg.in/natefinch/lumberjack.v2"
) )
func ParseRoller(c *caddy.Controller) (*LogRoller, error) { func Parse(c *caddy.Controller) (*LogRoller, error) {
var size, age, keep int var size, age, keep int
// This is kind of a hack to support nested blocks: // This is kind of a hack to support nested blocks:
// As we are already in a block: either log or errors, // As we are already in a block: either log or errors,

View file

@ -1,25 +1,39 @@
package middleware package storage
import ( import (
"net/http" "net/http"
"os" "os"
"path"
"path/filepath" "path/filepath"
"runtime" "runtime"
) )
// dir wraps http.Dir that restrict file access to a specific directory tree. // dir wraps an http.Dir that restrict file access to a specific directory tree, see http.Dir's documentation
// for methods for accessing files.
type dir http.Dir type dir http.Dir
// CoreDir is the directory where middleware can store assets, like zone files after a zone transfer // CoreDir is the directory where middleware can store assets, like zone files after a zone transfer
// or public and private keys or anything else a middleware might need. The convention is to place // or public and private keys or anything else a middleware might need. The convention is to place
// assets in a subdirectory named after the fully qualified zone. // assets in a subdirectory named after the zone prefixed with "D", to prevent the root zone become a hidden directory.
// //
// example.org./Kexample<something>.key // Dexample.org/Kexample.org<something>.key
//
// Note that subzone(s) under example.org are places in the own directory under CoreDir:
//
// Dexample.org/...
// Db.example.org/...
// //
// CoreDir will default to "$HOME/.coredns" on Unix, but it's location can be overriden with the COREDNSPATH // CoreDir will default to "$HOME/.coredns" on Unix, but it's location can be overriden with the COREDNSPATH
// environment variable. // environment variable.
var CoreDir dir = dir(fsPath()) var CoreDir dir = dir(fsPath())
func (d dir) Zone(z string) dir {
if z != "." && z[len(z)-2] == '.' {
return dir(path.Join(string(d), "D"+z[:len(z)-1]))
}
return dir(path.Join(string(d), "D"+z))
}
// fsPath returns the path to the directory where the application may store data. // fsPath returns the path to the directory where the application may store data.
// If COREDNSPATH env variable. is set, that value is used. Otherwise, the path is // If COREDNSPATH env variable. is set, that value is used. Otherwise, the path is
// the result of evaluating "$HOME/.coredns". // the result of evaluating "$HOME/.coredns".

View file

@ -0,0 +1,42 @@
package storage
import (
"os"
"path"
"strings"
"testing"
)
func TestfsPath(t *testing.T) {
if actual := fsPath(); !strings.HasSuffix(actual, ".coredns") {
t.Errorf("Expected path to be a .coredns folder, got: %v", actual)
}
os.Setenv("COREDNSPATH", "testpath")
defer os.Setenv("COREDNSPATH", "")
if actual, expected := fsPath(), "testpath"; actual != expected {
t.Errorf("Expected path to be %v, got: %v", expected, actual)
}
}
func TestZone(t *testing.T) {
for _, ts := range []string{"example.org.", "example.org"} {
d := CoreDir.Zone(ts)
actual := path.Base(string(d))
expected := "D" + ts
if actual != expected {
t.Errorf("Expected path to be %v, got %v", actual, expected)
}
}
}
func TestZoneRoot(t *testing.T) {
for _, ts := range []string{"."} {
d := CoreDir.Zone(ts)
actual := path.Base(string(d))
expected := "D" + ts
if actual != expected {
t.Errorf("Expected path to be %v, got %v", actual, expected)
}
}
}

View file

@ -7,7 +7,8 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/miekg/coredns/middleware" "github.com/miekg/coredns/request"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
@ -54,18 +55,19 @@ func New(hosts []string) Proxy {
// Lookup will use name and type to forge a new message and will send that upstream. It will // Lookup will use name and type to forge a new message and will send that upstream. It will
// set any EDNS0 options correctly so that downstream will be able to process the reply. // set any EDNS0 options correctly so that downstream will be able to process the reply.
// Lookup is not suitable for forwarding request. Ssee for that. // Lookup is not suitable for forwarding request. Ssee for that.
func (p Proxy) Lookup(state middleware.State, name string, tpe uint16) (*dns.Msg, error) { func (p Proxy) Lookup(state request.Request, name string, tpe uint16) (*dns.Msg, error) {
req := new(dns.Msg) req := new(dns.Msg)
req.SetQuestion(name, tpe) req.SetQuestion(name, tpe)
state.SizeAndDo(req) state.SizeAndDo(req)
return p.lookup(state, req) return p.lookup(state, req)
} }
func (p Proxy) Forward(state middleware.State) (*dns.Msg, error) { func (p Proxy) Forward(state request.Request) (*dns.Msg, error) {
return p.lookup(state, state.Req) return p.lookup(state, state.Req)
} }
func (p Proxy) lookup(state middleware.State, r *dns.Msg) (*dns.Msg, error) { func (p Proxy) lookup(state request.Request, r *dns.Msg) (*dns.Msg, error) {
var ( var (
reply *dns.Msg reply *dns.Msg
err error err error
@ -84,9 +86,9 @@ func (p Proxy) lookup(state middleware.State, r *dns.Msg) (*dns.Msg, error) {
atomic.AddInt64(&host.Conns, 1) atomic.AddInt64(&host.Conns, 1)
if state.Proto() == "tcp" { if state.Proto() == "tcp" {
reply, err = middleware.Exchange(p.Client.TCP, r, host.Name) reply, _, err = p.Client.TCP.Exchange(r, host.Name)
} else { } else {
reply, err = middleware.Exchange(p.Client.UDP, r, host.Name) reply, _, err = p.Client.UDP.Exchange(r, host.Name)
} }
atomic.AddInt64(&host.Conns, -1) atomic.AddInt64(&host.Conns, -1)

View file

@ -2,7 +2,7 @@
package proxy package proxy
import ( import (
"github.com/miekg/coredns/middleware" "github.com/miekg/coredns/request"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
@ -20,10 +20,10 @@ func (p ReverseProxy) ServeDNS(w dns.ResponseWriter, r *dns.Msg, extra []dns.RR)
) )
switch { switch {
case middleware.Proto(w) == "tcp": case request.Proto(w) == "tcp": // TODO(miek): keep this in request
reply, err = middleware.Exchange(p.Client.TCP, r, p.Host) reply, _, err = p.Client.TCP.Exchange(r, p.Host)
default: default:
reply, err = middleware.Exchange(p.Client.UDP, r, p.Host) reply, _, err = p.Client.UDP.Exchange(r, p.Host)
} }
if reply != nil && reply.Truncated { if reply != nil && reply.Truncated {

View file

@ -1,72 +0,0 @@
package middleware
import (
"time"
"github.com/miekg/dns"
)
// ResponseRecorder is a type of ResponseWriter that captures
// the rcode code written to it and also the size of the message
// written in the response. A rcode code does not have
// to be written, however, in which case 0 must be assumed.
// It is best to have the constructor initialize this type
// with that default status code.
type ResponseRecorder struct {
dns.ResponseWriter
rcode int
size int
msg *dns.Msg
start time.Time
}
// NewResponseRecorder makes and returns a new responseRecorder,
// which captures the DNS rcode from the ResponseWriter
// and also the length of the response message written through it.
func NewResponseRecorder(w dns.ResponseWriter) *ResponseRecorder {
return &ResponseRecorder{
ResponseWriter: w,
rcode: 0,
msg: nil,
start: time.Now(),
}
}
// WriteMsg records the status code and calls the
// underlying ResponseWriter's WriteMsg method.
func (r *ResponseRecorder) WriteMsg(res *dns.Msg) error {
r.rcode = res.Rcode
// We may get called multiple times (axfr for instance).
// Save the last message, but add the sizes.
r.size += res.Len()
r.msg = res
return r.ResponseWriter.WriteMsg(res)
}
// Write is a wrapper that records the size of the message that gets written.
func (r *ResponseRecorder) Write(buf []byte) (int, error) {
n, err := r.ResponseWriter.Write(buf)
if err == nil {
r.size += n
}
return n, err
}
// Size returns the size.
func (r *ResponseRecorder) Size() int { return r.size }
// Rcode returns the rcode.
func (r *ResponseRecorder) Rcode() string { return RcodeToString(r.rcode) }
// Start returns the start time of the ResponseRecorder.
func (r *ResponseRecorder) Start() time.Time { return r.start }
// Msg returns the written message from the ResponseRecorder.
func (r *ResponseRecorder) Msg() *dns.Msg { return r.msg }
// Hijack implements dns.Hijacker. It simply wraps the underlying
// ResponseWriter's Hijack method if there is one, or returns an error.
func (r *ResponseRecorder) Hijack() {
r.ResponseWriter.Hijack()
return
}

View file

@ -5,7 +5,8 @@ import (
"regexp" "regexp"
"strings" "strings"
"github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware/pkg/replacer"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
@ -25,8 +26,8 @@ func operatorError(operator string) error {
return fmt.Errorf("Invalid operator %v", operator) return fmt.Errorf("Invalid operator %v", operator)
} }
func newReplacer(r *dns.Msg) middleware.Replacer { func newReplacer(r *dns.Msg) replacer.Replacer {
return middleware.NewReplacer(r, nil, "") return replacer.New(r, nil, "")
} }
// condition is a rewrite condition. // condition is a rewrite condition.

View file

@ -117,143 +117,3 @@ func (s SimpleRule) Rewrite(r *dns.Msg) Result {
} }
return RewriteIgnored return RewriteIgnored
} }
/*
// ComplexRule is a rewrite rule based on a regular expression
type ComplexRule struct {
// Path base. Request to this path and subpaths will be rewritten
Base string
// Path to rewrite to
To string
// If set, neither performs rewrite nor proceeds
// with request. Only returns code.
Status int
// Extensions to filter by
Exts []string
// Rewrite conditions
Ifs []If
*regexp.Regexp
}
// NewComplexRule creates a new RegexpRule. It returns an error if regexp
// pattern (pattern) or extensions (ext) are invalid.
func NewComplexRule(base, pattern, to string, status int, ext []string, ifs []If) (*ComplexRule, error) {
// validate regexp if present
var r *regexp.Regexp
if pattern != "" {
var err error
r, err = regexp.Compile(pattern)
if err != nil {
return nil, err
}
}
// validate extensions if present
for _, v := range ext {
if len(v) < 2 || (len(v) < 3 && v[0] == '!') {
// check if no extension is specified
if v != "/" && v != "!/" {
return nil, fmt.Errorf("invalid extension %v", v)
}
}
}
return &ComplexRule{
Base: base,
To: to,
Status: status,
Exts: ext,
Ifs: ifs,
Regexp: r,
}, nil
}
// Rewrite rewrites the internal location of the current request.
func (r *ComplexRule) Rewrite(req *dns.Msg) (re Result) {
rPath := req.URL.Path
replacer := newReplacer(req)
// validate base
if !middleware.Path(rPath).Matches(r.Base) {
return
}
// validate extensions
if !r.matchExt(rPath) {
return
}
// validate regexp if present
if r.Regexp != nil {
// include trailing slash in regexp if present
start := len(r.Base)
if strings.HasSuffix(r.Base, "/") {
start--
}
matches := r.FindStringSubmatch(rPath[start:])
switch len(matches) {
case 0:
// no match
return
default:
// set regexp match variables {1}, {2} ...
for i := 1; i < len(matches); i++ {
replacer.Set(fmt.Sprint(i), matches[i])
}
}
}
// validate rewrite conditions
for _, i := range r.Ifs {
if !i.True(req) {
return
}
}
// if status is present, stop rewrite and return it.
if r.Status != 0 {
return RewriteStatus
}
// attempt rewrite
return To(fs, req, r.To, replacer)
}
// matchExt matches rPath against registered file extensions.
// Returns true if a match is found and false otherwise.
func (r *ComplexRule) matchExt(rPath string) bool {
f := filepath.Base(rPath)
ext := path.Ext(f)
if ext == "" {
ext = "/"
}
mustUse := false
for _, v := range r.Exts {
use := true
if v[0] == '!' {
use = false
v = v[1:]
}
if use {
mustUse = true
}
if ext == v {
return use
}
}
if mustUse {
return false
}
return true
}
*/

View file

@ -4,6 +4,7 @@ import (
"testing" "testing"
"github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/pkg/dnsrecorder"
"github.com/miekg/coredns/middleware/test" "github.com/miekg/coredns/middleware/test"
"github.com/miekg/dns" "github.com/miekg/dns"
@ -50,10 +51,10 @@ func TestRewrite(t *testing.T) {
m.SetQuestion(tc.from, tc.fromT) m.SetQuestion(tc.from, tc.fromT)
m.Question[0].Qclass = tc.fromC m.Question[0].Qclass = tc.fromC
rec := middleware.NewResponseRecorder(&test.ResponseWriter{}) rec := dnsrecorder.New(&test.ResponseWriter{})
rw.ServeDNS(ctx, rec, m) rw.ServeDNS(ctx, rec, m)
resp := rec.Msg()
resp := rec.Msg
if resp.Question[0].Name != tc.to { if resp.Question[0].Name != tc.to {
t.Errorf("Test %d: Expected Name to be '%s' but was '%s'", i, tc.to, resp.Question[0].Name) t.Errorf("Test %d: Expected Name to be '%s' but was '%s'", i, tc.to, resp.Question[0].Name)
} }

View file

@ -1,289 +0,0 @@
package middleware
import (
"testing"
"github.com/miekg/coredns/middleware/test"
"github.com/miekg/dns"
)
func TestStateDo(t *testing.T) {
st := testState()
st.Do()
if st.do == 0 {
t.Fatalf("expected st.do to be set")
}
}
func TestStateRemote(t *testing.T) {
st := testState()
if st.IP() != "10.240.0.1" {
t.Fatalf("wrong IP from state")
}
p, err := st.Port()
if err != nil {
t.Fatalf("failed to get Port from state")
}
if p != "40212" {
t.Fatalf("wrong port from state")
}
}
func BenchmarkStateDo(b *testing.B) {
st := testState()
for i := 0; i < b.N; i++ {
st.Do()
}
}
func BenchmarkStateSize(b *testing.B) {
st := testState()
for i := 0; i < b.N; i++ {
st.Size()
}
}
func testState() State {
m := new(dns.Msg)
m.SetQuestion("example.com.", dns.TypeA)
m.SetEdns0(4097, true)
return State{W: &test.ResponseWriter{}, Req: m}
}
/*
func TestHeader(t *testing.T) {
state := getContextOrFail(t)
headerKey, headerVal := "Header1", "HeaderVal1"
state.Req.Header.Add(headerKey, headerVal)
actualHeaderVal := state.Header(headerKey)
if actualHeaderVal != headerVal {
t.Errorf("Expected header %s, found %s", headerVal, actualHeaderVal)
}
missingHeaderVal := state.Header("not-existing")
if missingHeaderVal != "" {
t.Errorf("Expected empty header value, found %s", missingHeaderVal)
}
}
func TestIP(t *testing.T) {
state := getContextOrFail(t)
tests := []struct {
inputRemoteAddr string
expectedIP string
}{
// Test 0 - ipv4 with port
{"1.1.1.1:1111", "1.1.1.1"},
// Test 1 - ipv4 without port
{"1.1.1.1", "1.1.1.1"},
// Test 2 - ipv6 with port
{"[::1]:11", "::1"},
// Test 3 - ipv6 without port and brackets
{"[2001:db8:a0b:12f0::1]", "[2001:db8:a0b:12f0::1]"},
// Test 4 - ipv6 with zone and port
{`[fe80:1::3%eth0]:44`, `fe80:1::3%eth0`},
}
for i, test := range tests {
testPrefix := getTestPrefix(i)
state.Req.RemoteAddr = test.inputRemoteAddr
actualIP := state.IP()
if actualIP != test.expectedIP {
t.Errorf(testPrefix+"Expected IP %s, found %s", test.expectedIP, actualIP)
}
}
}
func TestURL(t *testing.T) {
state := getContextOrFail(t)
inputURL := "http://localhost"
state.Req.RequestURI = inputURL
if inputURL != state.URI() {
t.Errorf("Expected url %s, found %s", inputURL, state.URI())
}
}
func TestHost(t *testing.T) {
tests := []struct {
input string
expectedHost string
shouldErr bool
}{
{
input: "localhost:123",
expectedHost: "localhost",
shouldErr: false,
},
{
input: "localhost",
expectedHost: "localhost",
shouldErr: false,
},
{
input: "[::]",
expectedHost: "",
shouldErr: true,
},
}
for _, test := range tests {
testHostOrPort(t, true, test.input, test.expectedHost, test.shouldErr)
}
}
func TestPort(t *testing.T) {
tests := []struct {
input string
expectedPort string
shouldErr bool
}{
{
input: "localhost:123",
expectedPort: "123",
shouldErr: false,
},
{
input: "localhost",
expectedPort: "80", // assuming 80 is the default port
shouldErr: false,
},
{
input: ":8080",
expectedPort: "8080",
shouldErr: false,
},
{
input: "[::]",
expectedPort: "",
shouldErr: true,
},
}
for _, test := range tests {
testHostOrPort(t, false, test.input, test.expectedPort, test.shouldErr)
}
}
func testHostOrPort(t *testing.T, isTestingHost bool, input, expectedResult string, shouldErr bool) {
state := getContextOrFail(t)
state.Req.Host = input
var actualResult, testedObject string
var err error
if isTestingHost {
actualResult, err = state.Host()
testedObject = "host"
} else {
actualResult, err = state.Port()
testedObject = "port"
}
if shouldErr && err == nil {
t.Errorf("Expected error, found nil!")
return
}
if !shouldErr && err != nil {
t.Errorf("Expected no error, found %s", err)
return
}
if actualResult != expectedResult {
t.Errorf("Expected %s %s, found %s", testedObject, expectedResult, actualResult)
}
}
func TestPathMatches(t *testing.T) {
state := getContextOrFail(t)
tests := []struct {
urlStr string
pattern string
shouldMatch bool
}{
// Test 0
{
urlStr: "http://localhost/",
pattern: "",
shouldMatch: true,
},
// Test 1
{
urlStr: "http://localhost",
pattern: "",
shouldMatch: true,
},
// Test 1
{
urlStr: "http://localhost/",
pattern: "/",
shouldMatch: true,
},
// Test 3
{
urlStr: "http://localhost/?param=val",
pattern: "/",
shouldMatch: true,
},
// Test 4
{
urlStr: "http://localhost/dir1/dir2",
pattern: "/dir2",
shouldMatch: false,
},
// Test 5
{
urlStr: "http://localhost/dir1/dir2",
pattern: "/dir1",
shouldMatch: true,
},
// Test 6
{
urlStr: "http://localhost:444/dir1/dir2",
pattern: "/dir1",
shouldMatch: true,
},
}
for i, test := range tests {
testPrefix := getTestPrefix(i)
var err error
state.Req.URL, err = url.Parse(test.urlStr)
if err != nil {
t.Fatalf("Failed to prepare test URL from string %s! Error was: %s", test.urlStr, err)
}
matches := state.PathMatches(test.pattern)
if matches != test.shouldMatch {
t.Errorf(testPrefix+"Expected and actual result differ: expected to match [%t], actual matches [%t]", test.shouldMatch, matches)
}
}
}
func initTestContext() (Context, error) {
body := bytes.NewBufferString("request body")
request, err := http.NewRequest("GET", "https://localhost", body)
if err != nil {
return Context{}, err
}
return Context{Root: http.Dir(os.TempDir()), Req: request}, nil
}
func getTestPrefix(testN int) string {
return fmt.Sprintf("Test [%d]: ", testN)
}
*/

View file

@ -1,27 +0,0 @@
package middleware
import "github.com/miekg/dns"
type Zones []string
// Matches checks is qname is a subdomain of any of the zones in z. The match
// will return the most specific zones that matches other. The empty string
// signals a not found condition.
func (z Zones) Matches(qname string) string {
zone := ""
for _, zname := range z {
if dns.IsSubDomain(zname, qname) {
if len(zname) > len(zone) {
zone = zname
}
}
}
return zone
}
// FullyQualify fully qualifies all zones in z.
func (z Zones) FullyQualify() {
for i, _ := range z {
z[i] = dns.Fqdn(z[i])
}
}

View file

@ -1,17 +1,16 @@
package middleware package request
import ( import (
"net" "net"
"strings" "strings"
"time"
"github.com/miekg/coredns/middleware/pkg/edns"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
// This file contains the state functions available for use in the middlewares. // Request contains some connection state and is useful in middleware.
type Request struct {
// State contains some connection state and is useful in middleware.
type State struct {
Req *dns.Msg Req *dns.Msg
W dns.ResponseWriter W dns.ResponseWriter
@ -24,38 +23,41 @@ type State struct {
name string name string
} }
// Now returns the current timestamp in the specified format. // NewWithQuestion returns a new request based on the old, but with a new question
func (s *State) Now(format string) string { return time.Now().Format(format) } // section in the request.
func (r *Request) NewWithQuestion(name string, typ uint16) Request {
// NowDate returns the current date/time that can be used in other time functions. req1 := Request{W: r.W, Req: r.Req.Copy()}
func (s *State) NowDate() time.Time { return time.Now() } req1.Req.Question[0] = dns.Question{Name: dns.Fqdn(name), Qtype: dns.ClassINET, Qclass: typ}
return req1
}
// IP gets the (remote) IP address of the client making the request. // IP gets the (remote) IP address of the client making the request.
func (s *State) IP() string { func (r *Request) IP() string {
ip, _, err := net.SplitHostPort(s.W.RemoteAddr().String()) ip, _, err := net.SplitHostPort(r.W.RemoteAddr().String())
if err != nil { if err != nil {
return s.W.RemoteAddr().String() return r.W.RemoteAddr().String()
} }
return ip return ip
} }
// Post gets the (remote) Port of the client making the request. // Post gets the (remote) Port of the client making the request.
func (s *State) Port() (string, error) { func (r *Request) Port() string {
_, port, err := net.SplitHostPort(s.W.RemoteAddr().String()) _, port, err := net.SplitHostPort(r.W.RemoteAddr().String())
if err != nil { if err != nil {
return "0", err return "0"
} }
return port, nil return port
} }
// RemoteAddr returns the net.Addr of the client that sent the current request. // RemoteAddr returns the net.Addr of the client that sent the current request.
func (s *State) RemoteAddr() string { func (r *Request) RemoteAddr() string {
return s.W.RemoteAddr().String() return r.W.RemoteAddr().String()
} }
// Proto gets the protocol used as the transport. This will be udp or tcp. // Proto gets the protocol used as the transport. This will be udp or tcp.
func (s *State) Proto() string { return Proto(s.W) } func (r *Request) Proto() string { return Proto(r.W) }
// FIXME(miek): why not a method on Request
// Proto gets the protocol used as the transport. This will be udp or tcp. // Proto gets the protocol used as the transport. This will be udp or tcp.
func Proto(w dns.ResponseWriter) string { func Proto(w dns.ResponseWriter) string {
if _, ok := w.RemoteAddr().(*net.UDPAddr); ok { if _, ok := w.RemoteAddr().(*net.UDPAddr); ok {
@ -67,11 +69,10 @@ func Proto(w dns.ResponseWriter) string {
return "udp" return "udp"
} }
// Family returns the family of the transport. // Family returns the family of the transport, 1 for IPv4 and 2 for IPv6.
// 1 for IPv4 and 2 for IPv6. func (r *Request) Family() int {
func (s *State) Family() int {
var a net.IP var a net.IP
ip := s.W.RemoteAddr() ip := r.W.RemoteAddr()
if i, ok := ip.(*net.UDPAddr); ok { if i, ok := ip.(*net.UDPAddr); ok {
a = i.IP a = i.IP
} }
@ -86,48 +87,49 @@ func (s *State) Family() int {
} }
// Do returns if the request has the DO (DNSSEC OK) bit set. // Do returns if the request has the DO (DNSSEC OK) bit set.
func (s *State) Do() bool { func (r *Request) Do() bool {
if s.do != 0 { if r.do != 0 {
return s.do == doTrue return r.do == doTrue
} }
if o := s.Req.IsEdns0(); o != nil { if o := r.Req.IsEdns0(); o != nil {
if o.Do() { if o.Do() {
s.do = doTrue r.do = doTrue
} else { } else {
s.do = doFalse r.do = doFalse
} }
return o.Do() return o.Do()
} }
s.do = doFalse r.do = doFalse
return false return false
} }
// UDPSize returns if UDP buffer size advertised in the requests OPT record. // Size returns if UDP buffer size advertised in the requests OPT record.
// Or when the request was over TCP, we return the maximum allowed size of 64K. // Or when the request was over TCP, we return the maximum allowed size of 64K.
func (s *State) Size() int { func (r *Request) Size() int {
if s.size != 0 { if r.size != 0 {
return s.size return r.size
} }
size := 0 size := 0
if o := s.Req.IsEdns0(); o != nil { if o := r.Req.IsEdns0(); o != nil {
if o.Do() == true { if o.Do() == true {
s.do = doTrue r.do = doTrue
} else { } else {
s.do = doFalse r.do = doFalse
} }
size = int(o.UDPSize()) size = int(o.UDPSize())
} }
size = edns0Size(s.Proto(), size) // TODO(miek) move edns.Size to dnsutil?
s.size = size size = edns.Size(r.Proto(), size)
r.size = size
return size return size
} }
// SizeAndDo adds an OPT record that the reflects the intent from state. // SizeAndDo adds an OPT record that the reflects the intent from request.
// The returned bool indicated if an record was found and normalised. // The returned bool indicated if an record was found and normalised.
func (s *State) SizeAndDo(m *dns.Msg) bool { func (r *Request) SizeAndDo(m *dns.Msg) bool {
o := s.Req.IsEdns0() // TODO(miek): speed this up o := r.Req.IsEdns0() // TODO(miek): speed this up
if o == nil { if o == nil {
return false return false
} }
@ -163,8 +165,8 @@ const (
// the TC bit will be set regardless of protocol, even TCP message will get the bit, the client // the TC bit will be set regardless of protocol, even TCP message will get the bit, the client
// should then retry with pigeons. // should then retry with pigeons.
// TODO(referral). // TODO(referral).
func (s *State) Scrub(reply *dns.Msg) (*dns.Msg, Result) { func (r *Request) Scrub(reply *dns.Msg) (*dns.Msg, Result) {
size := s.Size() size := r.Size()
l := reply.Len() l := reply.Len()
if size >= l { if size >= l {
return reply, ScrubIgnored return reply, ScrubIgnored
@ -173,7 +175,7 @@ func (s *State) Scrub(reply *dns.Msg) (*dns.Msg, Result) {
// If not delegation, drop additional section. // If not delegation, drop additional section.
reply.Extra = nil reply.Extra = nil
s.SizeAndDo(reply) r.SizeAndDo(reply)
l = reply.Len() l = reply.Len()
if size >= l { if size >= l {
return reply, ScrubDone return reply, ScrubDone
@ -184,42 +186,42 @@ func (s *State) Scrub(reply *dns.Msg) (*dns.Msg, Result) {
} }
// Type returns the type of the question as a string. // Type returns the type of the question as a string.
func (s *State) Type() string { return dns.Type(s.Req.Question[0].Qtype).String() } func (r *Request) Type() string { return dns.Type(r.Req.Question[0].Qtype).String() }
// QType returns the type of the question as an uint16. // QType returns the type of the question as an uint16.
func (s *State) QType() uint16 { return s.Req.Question[0].Qtype } func (r *Request) QType() uint16 { return r.Req.Question[0].Qtype }
// Name returns the name of the question in the request. Note // Name returns the name of the question in the request. Note
// this name will always have a closing dot and will be lower cased. After a call Name // this name will always have a closing dot and will be lower cased. After a call Name
// the value will be cached. To clear this caching call Clear. // the value will be cached. To clear this caching call Clear.
func (s *State) Name() string { func (r *Request) Name() string {
if s.name != "" { if r.name != "" {
return s.name return r.name
} }
s.name = strings.ToLower(dns.Name(s.Req.Question[0].Name).String()) r.name = strings.ToLower(dns.Name(r.Req.Question[0].Name).String())
return s.name return r.name
} }
// QName returns the name of the question in the request. // QName returns the name of the question in the request.
func (s *State) QName() string { return dns.Name(s.Req.Question[0].Name).String() } func (r *Request) QName() string { return dns.Name(r.Req.Question[0].Name).String() }
// Class returns the class of the question in the request. // Class returns the class of the question in the request.
func (s *State) Class() string { return dns.Class(s.Req.Question[0].Qclass).String() } func (r *Request) Class() string { return dns.Class(r.Req.Question[0].Qclass).String() }
// QClass returns the class of the question in the request. // QClass returns the class of the question in the request.
func (s *State) QClass() uint16 { return s.Req.Question[0].Qclass } func (r *Request) QClass() uint16 { return r.Req.Question[0].Qclass }
// ErrorMessage returns an error message suitable for sending // ErrorMessage returns an error message suitable for sending
// back to the client. // back to the client.
func (s *State) ErrorMessage(rcode int) *dns.Msg { func (r *Request) ErrorMessage(rcode int) *dns.Msg {
m := new(dns.Msg) m := new(dns.Msg)
m.SetRcode(s.Req, rcode) m.SetRcode(r.Req, rcode)
return m return m
} }
// Clear clears all caching from State s. // Clear clears all caching from Request s.
func (s *State) Clear() { func (r *Request) Clear() {
s.name = "" r.name = ""
} }
const ( const (

55
request/request_test.go Normal file
View file

@ -0,0 +1,55 @@
package request
import (
"testing"
"github.com/miekg/coredns/middleware/test"
"github.com/miekg/dns"
)
func TestRequestDo(t *testing.T) {
st := testRequest()
st.Do()
if st.do == 0 {
t.Fatalf("Expected st.do to be set")
}
}
func TestRequestRemote(t *testing.T) {
st := testRequest()
if st.IP() != "10.240.0.1" {
t.Fatalf("Wrong IP from request")
}
p := st.Port()
if p == "" {
t.Fatalf("Failed to get Port from request")
}
if p != "40212" {
t.Fatalf("Wrong port from request")
}
}
func BenchmarkRequestDo(b *testing.B) {
st := testRequest()
for i := 0; i < b.N; i++ {
st.Do()
}
}
func BenchmarkRequestSize(b *testing.B) {
st := testRequest()
for i := 0; i < b.N; i++ {
st.Size()
}
}
func testRequest() Request {
m := new(dns.Msg)
m.SetQuestion("example.com.", dns.TypeA)
m.SetEdns0(4097, true)
return Request{W: &test.ResponseWriter{}, Req: m}
}

View file

@ -9,11 +9,11 @@ import (
"testing" "testing"
"time" "time"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/etcd" "github.com/miekg/coredns/middleware/etcd"
"github.com/miekg/coredns/middleware/etcd/msg" "github.com/miekg/coredns/middleware/etcd/msg"
"github.com/miekg/coredns/middleware/proxy" "github.com/miekg/coredns/middleware/proxy"
"github.com/miekg/coredns/middleware/test" "github.com/miekg/coredns/middleware/test"
"github.com/miekg/coredns/request"
etcdc "github.com/coreos/etcd/client" etcdc "github.com/coreos/etcd/client"
"github.com/miekg/dns" "github.com/miekg/dns"
@ -67,7 +67,7 @@ func TestEtcdStubAndProxyLookup(t *testing.T) {
} }
p := proxy.New([]string{udp}) // use udp port from the server p := proxy.New([]string{udp}) // use udp port from the server
state := middleware.State{W: &test.ResponseWriter{}, Req: new(dns.Msg)} state := request.Request{W: &test.ResponseWriter{}, Req: new(dns.Msg)}
resp, err := p.Lookup(state, "example.com.", dns.TypeA) resp, err := p.Lookup(state, "example.com.", dns.TypeA)
if err != nil { if err != nil {
t.Error("Expected to receive reply, but didn't") t.Error("Expected to receive reply, but didn't")

View file

@ -5,9 +5,9 @@ import (
"log" "log"
"testing" "testing"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/proxy" "github.com/miekg/coredns/middleware/proxy"
"github.com/miekg/coredns/middleware/test" "github.com/miekg/coredns/middleware/test"
"github.com/miekg/coredns/request"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
@ -46,7 +46,7 @@ func TestLookupProxy(t *testing.T) {
log.SetOutput(ioutil.Discard) log.SetOutput(ioutil.Discard)
p := proxy.New([]string{udp}) p := proxy.New([]string{udp})
state := middleware.State{W: &test.ResponseWriter{}, Req: new(dns.Msg)} state := request.Request{W: &test.ResponseWriter{}, Req: new(dns.Msg)}
resp, err := p.Lookup(state, "example.org.", dns.TypeA) resp, err := p.Lookup(state, "example.org.", dns.TypeA)
if err != nil { if err != nil {
t.Fatal("Expected to receive reply, but didn't") t.Fatal("Expected to receive reply, but didn't")

View file

@ -17,7 +17,7 @@ func TestProxyToChaosServer(t *testing.T) {
t.Fatalf("could not get CoreDNS serving instance: %s", err) t.Fatalf("could not get CoreDNS serving instance: %s", err)
} }
udpChaos, tcpChaos := CoreDNSServerPorts(chaos, 0) udpChaos, _ := CoreDNSServerPorts(chaos, 0)
defer chaos.Stop() defer chaos.Stop()
corefileProxy := `.:0 { corefileProxy := `.:0 {
@ -32,24 +32,23 @@ func TestProxyToChaosServer(t *testing.T) {
udp, _ := CoreDNSServerPorts(proxy, 0) udp, _ := CoreDNSServerPorts(proxy, 0)
defer proxy.Stop() defer proxy.Stop()
chaosTest(t, udpChaos, "udp") chaosTest(t, udpChaos)
chaosTest(t, tcpChaos, "tcp")
chaosTest(t, udp, "udp") chaosTest(t, udp)
// chaosTest(t, tcp, "tcp"), commented out because we use the original transport to reach the // chaosTest(t, tcp, "tcp"), commented out because we use the original transport to reach the
// proxy and we only forward to the udp port. // proxy and we only forward to the udp port.
} }
func chaosTest(t *testing.T, server, net string) { func chaosTest(t *testing.T, server string) {
m := Msg("version.bind.", dns.TypeTXT, nil) m := Msg("version.bind.", dns.TypeTXT, nil)
m.Question[0].Qclass = dns.ClassCHAOS m.Question[0].Qclass = dns.ClassCHAOS
r, err := Exchange(m, server, net) r, err := dns.Exchange(m, server)
if err != nil { if err != nil {
t.Fatalf("Could not send message: %s", err) t.Fatalf("Could not send message: %s", err)
} }
if r.Rcode != dns.RcodeSuccess || len(r.Answer) == 0 { if r.Rcode != dns.RcodeSuccess || len(r.Answer) == 0 {
t.Fatalf("Expected successful reply on %s, got %s", net, dns.RcodeToString[r.Rcode]) t.Fatalf("Expected successful reply, got %s", dns.RcodeToString[r.Rcode])
} }
if r.Answer[0].String() != `version.bind. 0 CH TXT "CoreDNS-001"` { if r.Answer[0].String() != `version.bind. 0 CH TXT "CoreDNS-001"` {
t.Fatalf("Expected version.bind. reply, got %s", r.Answer[0].String()) t.Fatalf("Expected version.bind. reply, got %s", r.Answer[0].String())

View file

@ -1,10 +1,6 @@
package test package test
import ( import "github.com/miekg/dns"
"github.com/miekg/coredns/middleware"
"github.com/miekg/dns"
)
func Msg(zone string, typ uint16, o *dns.OPT) *dns.Msg { func Msg(zone string, typ uint16, o *dns.OPT) *dns.Msg {
m := new(dns.Msg) m := new(dns.Msg)
@ -14,9 +10,3 @@ func Msg(zone string, typ uint16, o *dns.OPT) *dns.Msg {
} }
return m return m
} }
func Exchange(m *dns.Msg, server, net string) (*dns.Msg, error) {
c := new(dns.Client)
c.Net = net
return middleware.Exchange(c, m, server)
}