Cleanup: put middleware helper functions in pkgs (#245)

Move all (almost all) Go files in middleware into their
own packages. This makes for better naming and discoverability.

Lot of changes elsewhere to make this change.

The middleware.State was renamed to request.Request which is better,
but still does not cover all use-cases. It was also moved out middleware
because it is used by `dnsserver` as well.

A pkg/dnsutil packages was added for shared, handy, dns util functions.

All normalize functions are now put in normalize.go
This commit is contained in:
Miek Gieben 2016-09-07 11:10:16 +01:00 committed by GitHub
parent 684330fd28
commit d1f17fa7e0
90 changed files with 680 additions and 1037 deletions

View file

@ -7,7 +7,8 @@ import (
"sync"
"time"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/pkg/edns"
"github.com/miekg/coredns/request"
"github.com/miekg/dns"
"golang.org/x/net/context"
@ -163,7 +164,7 @@ func (s *Server) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
}
}()
if m, err := middleware.Edns0Version(r); err != nil { // Wrong EDNS version, return at once.
if m, err := edns.Version(r); err != nil { // Wrong EDNS version, return at once.
w.WriteMsg(m)
return
}
@ -214,10 +215,11 @@ func (s *Server) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
// DefaultErrorFunc responds to an DNS request with an error.
func DefaultErrorFunc(w dns.ResponseWriter, r *dns.Msg, rcode int) {
state := middleware.State{W: w, Req: r}
state := request.Request{W: w, Req: r}
answer := new(dns.Msg)
answer.SetRcode(r, rcode)
state.SizeAndDo(answer)
w.WriteMsg(answer)

View file

@ -6,6 +6,7 @@ import (
"time"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/pkg/response"
"github.com/miekg/dns"
gcache "github.com/patrickmn/go-cache"
@ -23,7 +24,7 @@ func NewCache(ttl int, zones []string, next middleware.Handler) Cache {
return Cache{Next: next, Zones: zones, cache: gcache.New(defaultDuration, purgeDuration), cap: time.Duration(ttl) * time.Second}
}
func cacheKey(m *dns.Msg, t middleware.MsgType, do bool) string {
func cacheKey(m *dns.Msg, t response.Type, do bool) string {
if m.Truncated {
return ""
}
@ -31,15 +32,15 @@ func cacheKey(m *dns.Msg, t middleware.MsgType, do bool) string {
qtype := m.Question[0].Qtype
qname := strings.ToLower(m.Question[0].Name)
switch t {
case middleware.Success:
case response.Success:
fallthrough
case middleware.Delegation:
case response.Delegation:
return successKey(qname, qtype, do)
case middleware.NameError:
case response.NameError:
return nameErrorKey(qname, do)
case middleware.NoData:
case response.NoData:
return noDataKey(qname, qtype, do)
case middleware.OtherError:
case response.OtherError:
return ""
}
return ""
@ -57,7 +58,7 @@ func NewCachingResponseWriter(w dns.ResponseWriter, cache *gcache.Cache, cap tim
func (c *CachingResponseWriter) WriteMsg(res *dns.Msg) error {
do := false
mt, opt := middleware.Classify(res)
mt, opt := response.Classify(res)
if opt != nil {
do = opt.Do()
}
@ -72,7 +73,7 @@ func (c *CachingResponseWriter) WriteMsg(res *dns.Msg) error {
return c.ResponseWriter.WriteMsg(res)
}
func (c *CachingResponseWriter) set(m *dns.Msg, key string, mt middleware.MsgType) {
func (c *CachingResponseWriter) set(m *dns.Msg, key string, mt response.Type) {
if key == "" {
log.Printf("[ERROR] Caching called with empty cache key")
return
@ -80,21 +81,21 @@ func (c *CachingResponseWriter) set(m *dns.Msg, key string, mt middleware.MsgTyp
duration := c.cap
switch mt {
case middleware.Success, middleware.Delegation:
case response.Success, response.Delegation:
if c.cap == 0 {
duration = minTtl(m.Answer, mt)
}
i := newItem(m, duration)
c.cache.Set(key, i, duration)
case middleware.NameError, middleware.NoData:
case response.NameError, response.NoData:
if c.cap == 0 {
duration = minTtl(m.Ns, mt)
}
i := newItem(m, duration)
c.cache.Set(key, i, duration)
case middleware.OtherError:
case response.OtherError:
// don't cache these
default:
log.Printf("[WARNING] Caching called with unknown middleware MsgType: %d", mt)
@ -112,19 +113,19 @@ func (c *CachingResponseWriter) Hijack() {
return
}
func minTtl(rrs []dns.RR, mt middleware.MsgType) time.Duration {
if mt != middleware.Success && mt != middleware.NameError && mt != middleware.NoData {
func minTtl(rrs []dns.RR, mt response.Type) time.Duration {
if mt != response.Success && mt != response.NameError && mt != response.NoData {
return 0
}
minTtl := maxTtl
for _, r := range rrs {
switch mt {
case middleware.NameError, middleware.NoData:
case response.NameError, response.NoData:
if r.Header().Rrtype == dns.TypeSOA {
return time.Duration(r.(*dns.SOA).Minttl) * time.Second
}
case middleware.Success, middleware.Delegation:
case response.Success, response.Delegation:
if r.Header().Ttl < minTtl {
minTtl = r.Header().Ttl
}

View file

@ -5,6 +5,7 @@ import (
"time"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/pkg/response"
"github.com/miekg/coredns/middleware/test"
"github.com/miekg/dns"
@ -78,7 +79,7 @@ func TestCache(t *testing.T) {
m = cacheMsg(m, tc)
do := tc.in.Do
mt, _ := middleware.Classify(m)
mt, _ := response.Classify(m)
key := cacheKey(m, mt, do)
crr.set(m, key, mt)

View file

@ -2,6 +2,7 @@ package cache
import (
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/request"
"github.com/miekg/dns"
"github.com/prometheus/client_golang/prometheus"
@ -10,7 +11,7 @@ import (
// ServeDNS implements the middleware.Handler interface.
func (c Cache) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
state := middleware.State{W: w, Req: r}
state := request.Request{W: w, Req: r}
qname := state.Name()
qtype := state.QType()

View file

@ -4,6 +4,7 @@ import (
"os"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/request"
"github.com/miekg/dns"
"golang.org/x/net/context"
@ -18,7 +19,7 @@ type Chaos struct {
}
func (c Chaos) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
state := middleware.State{W: w, Req: r}
state := request.Request{W: w, Req: r}
if state.QClass() != dns.ClassCHAOS || state.QType() != dns.TypeTXT {
return c.Next.ServeDNS(ctx, w, r)
}

View file

@ -4,6 +4,7 @@ import (
"testing"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/pkg/dnsrecorder"
"github.com/miekg/coredns/middleware/test"
"github.com/miekg/dns"
@ -58,7 +59,7 @@ func TestChaos(t *testing.T) {
req.Question[0].Qclass = dns.ClassCHAOS
em.Next = tc.next
rec := middleware.NewResponseRecorder(&test.ResponseWriter{})
rec := dnsrecorder.New(&test.ResponseWriter{})
code, err := em.ServeDNS(ctx, rec, req)
if err != tc.expectedErr {
@ -68,7 +69,7 @@ func TestChaos(t *testing.T) {
t.Errorf("Test %d: Expected status code %d, but got %d", i, tc.expectedCode, code)
}
if tc.expectedReply != "" {
answer := rec.Msg().Answer[0].(*dns.TXT).Txt[0]
answer := rec.Msg.Answer[0].(*dns.TXT).Txt[0]
if answer != tc.expectedReply {
t.Errorf("Test %d: Expected answer %s, but got %s", i, tc.expectedReply, answer)
}

View file

@ -4,8 +4,8 @@ import (
"testing"
"time"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/test"
"github.com/miekg/coredns/request"
"github.com/miekg/dns"
)
@ -16,7 +16,7 @@ func TestZoneSigningBlackLies(t *testing.T) {
defer rm2()
m := testNxdomainMsg()
state := middleware.State{Req: m}
state := request.Request{Req: m}
m = d.Sign(state, "miek.nl.", time.Now().UTC())
if !section(m.Ns, 2) {
t.Errorf("authority section should have 2 sig")

View file

@ -4,8 +4,8 @@ import (
"testing"
"time"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/test"
"github.com/miekg/coredns/request"
)
func TestCacheSet(t *testing.T) {
@ -20,7 +20,7 @@ func TestCacheSet(t *testing.T) {
}
m := testMsg()
state := middleware.State{Req: m}
state := request.Request{Req: m}
k := key(m.Answer) // calculate *before* we add the sig
d := New([]string{"miek.nl."}, []*DNSKEY{dnskey}, nil)
m = d.Sign(state, "miek.nl.", time.Now().UTC())

View file

@ -8,7 +8,7 @@ import (
"os"
"time"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/request"
"github.com/miekg/dns"
)
@ -50,7 +50,7 @@ func ParseKeyFile(pubFile, privFile string) (*DNSKEY, error) {
}
// getDNSKEY returns the correct DNSKEY to the client. Signatures are added when do is true.
func (d Dnssec) getDNSKEY(state middleware.State, zone string, do bool) *dns.Msg {
func (d Dnssec) getDNSKEY(state request.Request, zone string, do bool) *dns.Msg {
keys := make([]dns.RR, len(d.keys))
for i, k := range d.keys {
keys[i] = dns.Copy(k.K)

View file

@ -4,7 +4,9 @@ import (
"time"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/singleflight"
"github.com/miekg/coredns/middleware/pkg/response"
"github.com/miekg/coredns/middleware/pkg/singleflight"
"github.com/miekg/coredns/request"
"github.com/miekg/dns"
gcache "github.com/patrickmn/go-cache"
@ -28,20 +30,21 @@ func New(zones []string, keys []*DNSKEY, next middleware.Handler) Dnssec {
}
}
// Sign signs the message m. it takes care of negative or nodata responses. It
// Sign signs the message in state. it takes care of negative or nodata responses. It
// uses NSEC black lies for authenticated denial of existence. Signatures
// creates will be cached for a short while. By default we sign for 8 days,
// starting 3 hours ago.
func (d Dnssec) Sign(state middleware.State, zone string, now time.Time) *dns.Msg {
func (d Dnssec) Sign(state request.Request, zone string, now time.Time) *dns.Msg {
req := state.Req
mt, _ := middleware.Classify(req) // TODO(miek): need opt record here?
if mt == middleware.Delegation {
mt, _ := response.Classify(req) // TODO(miek): need opt record here?
if mt == response.Delegation {
return req
}
incep, expir := incepExpir(now)
if mt == middleware.NameError {
if mt == response.NameError {
if req.Ns[0].Header().Rrtype != dns.TypeSOA || len(req.Ns) > 1 {
return req
}

View file

@ -4,8 +4,8 @@ import (
"testing"
"time"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/test"
"github.com/miekg/coredns/request"
"github.com/miekg/dns"
)
@ -16,7 +16,7 @@ func TestZoneSigning(t *testing.T) {
defer rm2()
m := testMsg()
state := middleware.State{Req: m}
state := request.Request{Req: m}
m = d.Sign(state, "miek.nl.", time.Now().UTC())
if !section(m.Answer, 1) {
@ -44,7 +44,7 @@ func TestZoneSigningDouble(t *testing.T) {
d.keys = append(d.keys, key1)
m := testMsg()
state := middleware.State{Req: m}
state := request.Request{Req: m}
m = d.Sign(state, "miek.nl.", time.Now().UTC())
if !section(m.Answer, 2) {
t.Errorf("answer section should have 1 sig")
@ -68,7 +68,7 @@ func TestSigningDifferentZone(t *testing.T) {
}
m := testMsgEx()
state := middleware.State{Req: m}
state := request.Request{Req: m}
d := New([]string{"example.org."}, []*DNSKEY{key}, nil)
m = d.Sign(state, "example.org.", time.Now().UTC())
if !section(m.Answer, 1) {
@ -86,7 +86,7 @@ func TestSigningCname(t *testing.T) {
defer rm2()
m := testMsgCname()
state := middleware.State{Req: m}
state := request.Request{Req: m}
m = d.Sign(state, "miek.nl.", time.Now().UTC())
if !section(m.Answer, 1) {
t.Errorf("answer section should have 1 sig")
@ -100,7 +100,7 @@ func TestZoneSigningDelegation(t *testing.T) {
defer rm2()
m := testDelegationMsg()
state := middleware.State{Req: m}
state := request.Request{Req: m}
m = d.Sign(state, "miek.nl.", time.Now().UTC())
if !section(m.Ns, 0) {
t.Errorf("authority section should have 0 sig")

View file

@ -2,6 +2,7 @@ package dnssec
import (
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/request"
"github.com/miekg/dns"
"github.com/prometheus/client_golang/prometheus"
@ -10,7 +11,7 @@ import (
// ServeDNS implements the middleware.Handler interface.
func (d Dnssec) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
state := middleware.State{W: w, Req: r}
state := request.Request{W: w, Req: r}
do := state.Do()
qname := state.Name()

View file

@ -5,8 +5,8 @@ import (
"strings"
"testing"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/file"
"github.com/miekg/coredns/middleware/pkg/dnsrecorder"
"github.com/miekg/coredns/middleware/test"
"github.com/miekg/dns"
@ -83,14 +83,14 @@ func TestLookupZone(t *testing.T) {
for _, tc := range dnsTestCases {
m := tc.Msg()
rec := middleware.NewResponseRecorder(&test.ResponseWriter{})
rec := dnsrecorder.New(&test.ResponseWriter{})
_, err := dh.ServeDNS(ctx, rec, m)
if err != nil {
t.Errorf("expected no error, got %v\n", err)
return
}
resp := rec.Msg()
resp := rec.Msg
sort.Sort(test.RRSet(resp.Answer))
sort.Sort(test.RRSet(resp.Ns))
sort.Sort(test.RRSet(resp.Extra))
@ -121,14 +121,14 @@ func TestLookupDNSKEY(t *testing.T) {
for _, tc := range dnssecTestCases {
m := tc.Msg()
rec := middleware.NewResponseRecorder(&test.ResponseWriter{})
rec := dnsrecorder.New(&test.ResponseWriter{})
_, err := dh.ServeDNS(ctx, rec, m)
if err != nil {
t.Errorf("expected no error, got %v\n", err)
return
}
resp := rec.Msg()
resp := rec.Msg
if !resp.Authoritative {
t.Errorf("Authoritative Answer should be true, got false")
}

View file

@ -5,6 +5,8 @@ import (
"time"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/request"
"github.com/miekg/dns"
)
@ -20,7 +22,7 @@ func NewDnssecResponseWriter(w dns.ResponseWriter, d Dnssec) *DnssecResponseWrit
func (d *DnssecResponseWriter) WriteMsg(res *dns.Msg) error {
// By definition we should sign anything that comes back, we should still figure out for
// which zone it should be.
state := middleware.State{W: d.ResponseWriter, Req: res}
state := request.Request{W: d.ResponseWriter, Req: res}
qname := state.Name()
zone := middleware.Zones(d.d.zones).Matches(qname)

View file

@ -9,6 +9,8 @@ import (
"time"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/pkg/roller"
"github.com/miekg/coredns/request"
"github.com/miekg/dns"
"golang.org/x/net/context"
@ -19,7 +21,7 @@ type ErrorHandler struct {
Next middleware.Handler
LogFile string
Log *log.Logger
LogRoller *middleware.LogRoller
LogRoller *roller.LogRoller
Debug bool // if true, errors are written out to client rather than to a log
}
@ -29,7 +31,7 @@ func (h ErrorHandler) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns
rcode, err := h.Next.ServeDNS(ctx, w, r)
if err != nil {
state := middleware.State{W: w, Req: r}
state := request.Request{W: w, Req: r}
errMsg := fmt.Sprintf("%s [ERROR %d %s %s] %v", time.Now().Format(timeFormat), rcode, state.Name(), state.Type(), err)
if h.Debug {
@ -53,7 +55,7 @@ func (h ErrorHandler) recovery(ctx context.Context, w dns.ResponseWriter, r *dns
return
}
state := middleware.State{W: w, Req: r}
state := request.Request{W: w, Req: r}
// Obtain source of panic
// From: https://gist.github.com/swdunlop/9629168
var name, file string // function name, file name

View file

@ -9,6 +9,7 @@ import (
"testing"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/pkg/dnsrecorder"
"github.com/miekg/coredns/middleware/test"
"github.com/miekg/dns"
@ -47,7 +48,7 @@ func TestErrors(t *testing.T) {
for i, tc := range tests {
em.Next = tc.next
buf.Reset()
rec := middleware.NewResponseRecorder(&test.ResponseWriter{})
rec := dnsrecorder.New(&test.ResponseWriter{})
code, err := em.ServeDNS(ctx, rec, req)
if err != tc.expectedErr {
@ -78,7 +79,7 @@ func TestVisibleErrorWithPanic(t *testing.T) {
req := new(dns.Msg)
req.SetQuestion("example.org.", dns.TypeA)
rec := middleware.NewResponseRecorder(&test.ResponseWriter{})
rec := dnsrecorder.New(&test.ResponseWriter{})
code, err := eh.ServeDNS(ctx, rec, req)
if code != 0 {

View file

@ -6,7 +6,7 @@ import (
"os"
"github.com/miekg/coredns/core/dnsserver"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/pkg/roller"
"github.com/hashicorp/go-syslog"
"github.com/mholt/caddy"
@ -93,7 +93,7 @@ func errorsParse(c *caddy.Controller) (ErrorHandler, error) {
if c.NextArg() {
if c.Val() == "{" {
c.IncrNest()
logRoller, err := middleware.ParseRoller(c)
logRoller, err := roller.Parse(c)
if err != nil {
return hadBlock, err
}

View file

@ -4,7 +4,7 @@ import (
"testing"
"github.com/mholt/caddy"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/pkg/roller"
)
func TestErrorsParse(t *testing.T) {
@ -29,7 +29,7 @@ func TestErrorsParse(t *testing.T) {
}},
{`errors { log errors.txt { size 2 age 10 keep 3 } }`, false, ErrorHandler{
LogFile: "errors.txt",
LogRoller: &middleware.LogRoller{
LogRoller: &roller.LogRoller{
MaxSize: 2,
MaxAge: 10,
MaxBackups: 3,
@ -43,7 +43,7 @@ func TestErrorsParse(t *testing.T) {
}
}`, false, ErrorHandler{
LogFile: "errors.txt",
LogRoller: &middleware.LogRoller{
LogRoller: &roller.LogRoller{
MaxSize: 3,
MaxAge: 11,
MaxBackups: 5,

View file

@ -7,8 +7,8 @@ package etcd
import (
"testing"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/etcd/msg"
"github.com/miekg/coredns/middleware/pkg/dnsrecorder"
"github.com/miekg/coredns/middleware/test"
"github.com/miekg/dns"
@ -25,14 +25,14 @@ func TestCnameLookup(t *testing.T) {
for _, tc := range dnsTestCasesCname {
m := tc.Msg()
rec := middleware.NewResponseRecorder(&test.ResponseWriter{})
rec := dnsrecorder.New(&test.ResponseWriter{})
_, err := etc.ServeDNS(ctxt, rec, m)
if err != nil {
t.Errorf("expected no error, got %v\n", err)
return
}
resp := rec.Msg()
resp := rec.Msg
if !test.Header(t, tc, resp) {
t.Logf("%v\n", resp)
continue

View file

@ -7,8 +7,8 @@ import (
"strings"
"testing"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/etcd/msg"
"github.com/miekg/coredns/middleware/pkg/dnsrecorder"
"github.com/miekg/coredns/middleware/test"
"github.com/miekg/dns"
@ -41,14 +41,14 @@ func TestDebugLookup(t *testing.T) {
for _, tc := range dnsTestCasesDebug {
m := tc.Msg()
rec := middleware.NewResponseRecorder(&test.ResponseWriter{})
rec := dnsrecorder.New(&test.ResponseWriter{})
_, err := etc.ServeDNS(ctxt, rec, m)
if err != nil {
t.Errorf("expected no error, got %v\n", err)
continue
}
resp := rec.Msg()
resp := rec.Msg
sort.Sort(test.RRSet(resp.Answer))
sort.Sort(test.RRSet(resp.Ns))
sort.Sort(test.RRSet(resp.Extra))
@ -79,14 +79,14 @@ func TestDebugLookupFalse(t *testing.T) {
for _, tc := range dnsTestCasesDebugFalse {
m := tc.Msg()
rec := middleware.NewResponseRecorder(&test.ResponseWriter{})
rec := dnsrecorder.New(&test.ResponseWriter{})
_, err := etc.ServeDNS(ctxt, rec, m)
if err != nil {
t.Errorf("expected no error, got %v\n", err)
continue
}
resp := rec.Msg()
resp := rec.Msg
sort.Sort(test.RRSet(resp.Answer))
sort.Sort(test.RRSet(resp.Ns))
sort.Sort(test.RRSet(resp.Extra))

View file

@ -9,8 +9,8 @@ import (
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/etcd/msg"
"github.com/miekg/coredns/middleware/pkg/singleflight"
"github.com/miekg/coredns/middleware/proxy"
"github.com/miekg/coredns/singleflight"
etcdc "github.com/coreos/etcd/client"
"golang.org/x/net/context"

View file

@ -6,8 +6,8 @@ import (
"sort"
"testing"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/etcd/msg"
"github.com/miekg/coredns/middleware/pkg/dnsrecorder"
"github.com/miekg/coredns/middleware/test"
"github.com/miekg/dns"
@ -23,14 +23,14 @@ func TestGroupLookup(t *testing.T) {
for _, tc := range dnsTestCasesGroup {
m := tc.Msg()
rec := middleware.NewResponseRecorder(&test.ResponseWriter{})
rec := dnsrecorder.New(&test.ResponseWriter{})
_, err := etc.ServeDNS(ctxt, rec, m)
if err != nil {
t.Errorf("expected no error, got %v\n", err)
continue
}
resp := rec.Msg()
resp := rec.Msg
sort.Sort(test.RRSet(resp.Answer))
sort.Sort(test.RRSet(resp.Ns))
sort.Sort(test.RRSet(resp.Extra))

View file

@ -5,6 +5,7 @@ import (
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/etcd/msg"
"github.com/miekg/coredns/request"
"github.com/miekg/dns"
"golang.org/x/net/context"
@ -12,7 +13,7 @@ import (
func (e *Etcd) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
opt := Options{}
state := middleware.State{W: w, Req: r}
state := request.Request{W: w, Req: r}
if state.QClass() != dns.ClassINET {
return dns.RcodeServerFailure, fmt.Errorf("can only deal with ClassINET")
}
@ -115,7 +116,7 @@ func (e *Etcd) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (
}
// Err write an error response to the client.
func (e *Etcd) Err(zone string, rcode int, state middleware.State, debug []msg.Service, err error, opt Options) (int, error) {
func (e *Etcd) Err(zone string, rcode int, state request.Request, debug []msg.Service, err error, opt Options) (int, error) {
m := new(dns.Msg)
m.SetRcode(state.Req, rcode)
m.Authoritative, m.RecursionAvailable, m.Compress = true, true, true

View file

@ -8,6 +8,8 @@ import (
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/etcd/msg"
"github.com/miekg/coredns/middleware/pkg/dnsutil"
"github.com/miekg/coredns/request"
"github.com/miekg/dns"
)
@ -16,7 +18,7 @@ type Options struct {
Debug string
}
func (e Etcd) records(state middleware.State, exact bool, opt Options) (services, debug []msg.Service, err error) {
func (e Etcd) records(state request.Request, exact bool, opt Options) (services, debug []msg.Service, err error) {
services, err = e.Records(state.Name(), exact)
if err != nil {
return
@ -28,7 +30,7 @@ func (e Etcd) records(state middleware.State, exact bool, opt Options) (services
return
}
func (e Etcd) A(zone string, state middleware.State, previousRecords []dns.RR, opt Options) (records []dns.RR, debug []msg.Service, err error) {
func (e Etcd) A(zone string, state request.Request, previousRecords []dns.RR, opt Options) (records []dns.RR, debug []msg.Service, err error) {
services, debug, err := e.records(state, false, opt)
if err != nil {
return nil, debug, err
@ -49,11 +51,11 @@ func (e Etcd) A(zone string, state middleware.State, previousRecords []dns.RR, o
// don't add it, and just continue
continue
}
if isDuplicateCNAME(newRecord, previousRecords) {
if dnsutil.DuplicateCNAME(newRecord, previousRecords) {
continue
}
state1 := copyState(state, serv.Host, state.QType())
state1 := state.NewWithQuestion(serv.Host, state.QType())
nextRecords, nextDebug, err := e.A(zone, state1, append(previousRecords, newRecord), opt)
if err == nil {
@ -90,7 +92,7 @@ func (e Etcd) A(zone string, state middleware.State, previousRecords []dns.RR, o
return records, debug, nil
}
func (e Etcd) AAAA(zone string, state middleware.State, previousRecords []dns.RR, opt Options) (records []dns.RR, debug []msg.Service, err error) {
func (e Etcd) AAAA(zone string, state request.Request, previousRecords []dns.RR, opt Options) (records []dns.RR, debug []msg.Service, err error) {
services, debug, err := e.records(state, false, opt)
if err != nil {
return nil, debug, err
@ -111,11 +113,11 @@ func (e Etcd) AAAA(zone string, state middleware.State, previousRecords []dns.RR
// don't add it, and just continue
continue
}
if isDuplicateCNAME(newRecord, previousRecords) {
if dnsutil.DuplicateCNAME(newRecord, previousRecords) {
continue
}
state1 := copyState(state, serv.Host, state.QType())
state1 := state.NewWithQuestion(serv.Host, state.QType())
nextRecords, nextDebug, err := e.AAAA(zone, state1, append(previousRecords, newRecord), opt)
if err == nil {
@ -155,7 +157,7 @@ func (e Etcd) AAAA(zone string, state middleware.State, previousRecords []dns.RR
// SRV returns SRV records from etcd.
// If the Target is not a name but an IP address, a name is created on the fly.
func (e Etcd) SRV(zone string, state middleware.State, opt Options) (records, extra []dns.RR, debug []msg.Service, err error) {
func (e Etcd) SRV(zone string, state request.Request, opt Options) (records, extra []dns.RR, debug []msg.Service, err error) {
services, debug, err := e.records(state, false, opt)
if err != nil {
return nil, nil, nil, err
@ -220,7 +222,7 @@ func (e Etcd) SRV(zone string, state middleware.State, opt Options) (records, ex
}
// Internal name, we should have some info on them, either v4 or v6
// Clients expect a complete answer, because we are a recursor in their view.
state1 := copyState(state, srv.Target, dns.TypeA)
state1 := state.NewWithQuestion(srv.Target, dns.TypeA)
addr, debugAddr, e1 := e.A(zone, state1, nil, opt)
if e1 == nil {
extra = append(extra, addr...)
@ -246,7 +248,7 @@ func (e Etcd) SRV(zone string, state middleware.State, opt Options) (records, ex
// MX returns MX records from etcd.
// If the Target is not a name but an IP address, a name is created on the fly.
func (e Etcd) MX(zone string, state middleware.State, opt Options) (records, extra []dns.RR, debug []msg.Service, err error) {
func (e Etcd) MX(zone string, state request.Request, opt Options) (records, extra []dns.RR, debug []msg.Service, err error) {
services, debug, err := e.records(state, false, opt)
if err != nil {
return nil, nil, debug, err
@ -291,7 +293,7 @@ func (e Etcd) MX(zone string, state middleware.State, opt Options) (records, ext
break
}
// Internal name
state1 := copyState(state, mx.Mx, dns.TypeA)
state1 := state.NewWithQuestion(mx.Mx, dns.TypeA)
addr, debugAddr, e1 := e.A(zone, state1, nil, opt)
if e1 == nil {
extra = append(extra, addr...)
@ -311,7 +313,7 @@ func (e Etcd) MX(zone string, state middleware.State, opt Options) (records, ext
return records, extra, debug, nil
}
func (e Etcd) CNAME(zone string, state middleware.State, opt Options) (records []dns.RR, debug []msg.Service, err error) {
func (e Etcd) CNAME(zone string, state request.Request, opt Options) (records []dns.RR, debug []msg.Service, err error) {
services, debug, err := e.records(state, true, opt)
if err != nil {
return nil, debug, err
@ -327,7 +329,7 @@ func (e Etcd) CNAME(zone string, state middleware.State, opt Options) (records [
}
// PTR returns the PTR records, only services that have a domain name as host are included.
func (e Etcd) PTR(zone string, state middleware.State, opt Options) (records []dns.RR, debug []msg.Service, err error) {
func (e Etcd) PTR(zone string, state request.Request, opt Options) (records []dns.RR, debug []msg.Service, err error) {
services, debug, err := e.records(state, true, opt)
if err != nil {
return nil, debug, err
@ -341,7 +343,7 @@ func (e Etcd) PTR(zone string, state middleware.State, opt Options) (records []d
return records, debug, nil
}
func (e Etcd) TXT(zone string, state middleware.State, opt Options) (records []dns.RR, debug []msg.Service, err error) {
func (e Etcd) TXT(zone string, state request.Request, opt Options) (records []dns.RR, debug []msg.Service, err error) {
services, debug, err := e.records(state, false, opt)
if err != nil {
return nil, debug, err
@ -356,7 +358,7 @@ func (e Etcd) TXT(zone string, state middleware.State, opt Options) (records []d
return records, debug, nil
}
func (e Etcd) NS(zone string, state middleware.State, opt Options) (records, extra []dns.RR, debug []msg.Service, err error) {
func (e Etcd) NS(zone string, state request.Request, opt Options) (records, extra []dns.RR, debug []msg.Service, err error) {
// NS record for this zone live in a special place, ns.dns.<zone>. Fake our lookup.
// only a tad bit fishy...
old := state.QName()
@ -389,7 +391,7 @@ func (e Etcd) NS(zone string, state middleware.State, opt Options) (records, ext
}
// SOA Record returns a SOA record.
func (e Etcd) SOA(zone string, state middleware.State, opt Options) ([]dns.RR, []msg.Service, error) {
func (e Etcd) SOA(zone string, state request.Request, opt Options) ([]dns.RR, []msg.Service, error) {
header := dns.RR_Header{Name: zone, Rrtype: dns.TypeSOA, Ttl: 300, Class: dns.ClassINET}
soa := &dns.SOA{Hdr: header,
@ -404,21 +406,3 @@ func (e Etcd) SOA(zone string, state middleware.State, opt Options) ([]dns.RR, [
// TODO(miek): fake some msg.Service here when returning.
return []dns.RR{soa}, nil, nil
}
func isDuplicateCNAME(r *dns.CNAME, records []dns.RR) bool {
for _, rec := range records {
if v, ok := rec.(*dns.CNAME); ok {
if v.Target == r.Target {
return true
}
}
}
return false
}
// TODO(miek): Move to middleware?
func copyState(state middleware.State, target string, typ uint16) middleware.State {
state1 := middleware.State{W: state.W, Req: state.Req.Copy()}
state1.Req.Question[0] = dns.Question{Name: dns.Fqdn(target), Qclass: dns.ClassINET, Qtype: typ}
return state1
}

View file

@ -6,8 +6,8 @@ import (
"sort"
"testing"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/etcd/msg"
"github.com/miekg/coredns/middleware/pkg/dnsrecorder"
"github.com/miekg/coredns/middleware/test"
"github.com/miekg/dns"
@ -25,14 +25,14 @@ func TestMultiLookup(t *testing.T) {
for _, tc := range dnsTestCasesMulti {
m := tc.Msg()
rec := middleware.NewResponseRecorder(&test.ResponseWriter{})
rec := dnsrecorder.New(&test.ResponseWriter{})
_, err := etc.ServeDNS(ctxt, rec, m)
if err != nil {
t.Errorf("expected no error, got %v\n", err)
return
}
resp := rec.Msg()
resp := rec.Msg
sort.Sort(test.RRSet(resp.Answer))
sort.Sort(test.RRSet(resp.Ns))
sort.Sort(test.RRSet(resp.Extra))

View file

@ -10,8 +10,8 @@ import (
"strings"
"testing"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/etcd/msg"
"github.com/miekg/coredns/middleware/pkg/dnsrecorder"
"github.com/miekg/coredns/middleware/test"
"github.com/miekg/dns"
@ -27,14 +27,14 @@ func TestOtherLookup(t *testing.T) {
for _, tc := range dnsTestCasesOther {
m := tc.Msg()
rec := middleware.NewResponseRecorder(&test.ResponseWriter{})
rec := dnsrecorder.New(&test.ResponseWriter{})
_, err := etc.ServeDNS(ctxt, rec, m)
if err != nil {
t.Errorf("expected no error, got %v\n", err)
continue
}
resp := rec.Msg()
resp := rec.Msg
sort.Sort(test.RRSet(resp.Answer))
sort.Sort(test.RRSet(resp.Ns))
sort.Sort(test.RRSet(resp.Extra))

View file

@ -6,8 +6,8 @@ import (
"sort"
"testing"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/etcd/msg"
"github.com/miekg/coredns/middleware/pkg/dnsrecorder"
"github.com/miekg/coredns/middleware/proxy"
"github.com/miekg/coredns/middleware/test"
@ -27,14 +27,14 @@ func TestProxyLookupFailDebug(t *testing.T) {
for _, tc := range dnsTestCasesProxy {
m := tc.Msg()
rec := middleware.NewResponseRecorder(&test.ResponseWriter{})
rec := dnsrecorder.New(&test.ResponseWriter{})
_, err := etc.ServeDNS(ctxt, rec, m)
if err != nil {
t.Errorf("expected no error, got %v\n", err)
continue
}
resp := rec.Msg()
resp := rec.Msg
sort.Sort(test.RRSet(resp.Answer))
sort.Sort(test.RRSet(resp.Ns))
sort.Sort(test.RRSet(resp.Extra))

View file

@ -10,8 +10,8 @@ import (
"github.com/miekg/coredns/core/dnsserver"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/pkg/singleflight"
"github.com/miekg/coredns/middleware/proxy"
"github.com/miekg/coredns/singleflight"
etcdc "github.com/coreos/etcd/client"
"github.com/mholt/caddy"
@ -70,7 +70,7 @@ func etcdParse(c *caddy.Controller) (*Etcd, bool, error) {
etc.Zones = make([]string, len(c.ServerBlockKeys))
copy(etc.Zones, c.ServerBlockKeys)
}
middleware.Zones(etc.Zones).FullyQualify()
middleware.Zones(etc.Zones).Normalize()
if c.NextBlock() {
// TODO(miek): 2 switches?
switch c.Val() {

View file

@ -8,11 +8,11 @@ import (
"testing"
"time"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/etcd/msg"
"github.com/miekg/coredns/middleware/pkg/dnsrecorder"
"github.com/miekg/coredns/middleware/pkg/singleflight"
"github.com/miekg/coredns/middleware/proxy"
"github.com/miekg/coredns/middleware/test"
"github.com/miekg/coredns/singleflight"
etcdc "github.com/coreos/etcd/client"
"github.com/miekg/dns"
@ -65,14 +65,14 @@ func TestLookup(t *testing.T) {
for _, tc := range dnsTestCases {
m := tc.Msg()
rec := middleware.NewResponseRecorder(&test.ResponseWriter{})
rec := dnsrecorder.New(&test.ResponseWriter{})
_, err := etc.ServeDNS(ctxt, rec, m)
if err != nil {
t.Errorf("expected no error, got: %v for %s %s\n", err, m.Question[0].Name, dns.Type(m.Question[0].Qtype))
return
}
resp := rec.Msg()
resp := rec.Msg
sort.Sort(test.RRSet(resp.Answer))
sort.Sort(test.RRSet(resp.Ns))
sort.Sort(test.RRSet(resp.Extra))

View file

@ -4,7 +4,7 @@ import (
"errors"
"log"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/request"
"github.com/miekg/dns"
"golang.org/x/net/context"
@ -27,7 +27,7 @@ func (s Stub) ServeDNS(ctx context.Context, w dns.ResponseWriter, req *dns.Msg)
return dns.RcodeServerFailure, nil
}
state := middleware.State{W: w, Req: req}
state := request.Request{W: w, Req: req}
m, e := proxy.Forward(state)
if e != nil {
return dns.RcodeServerFailure, e

View file

@ -8,8 +8,8 @@ import (
"strconv"
"testing"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/etcd/msg"
"github.com/miekg/coredns/middleware/pkg/dnsrecorder"
"github.com/miekg/coredns/middleware/test"
"github.com/miekg/dns"
@ -53,7 +53,7 @@ func TestStubLookup(t *testing.T) {
for _, tc := range dnsTestCasesStub {
m := tc.Msg()
rec := middleware.NewResponseRecorder(&test.ResponseWriter{})
rec := dnsrecorder.New(&test.ResponseWriter{})
_, err := etc.ServeDNS(ctxt, rec, m)
if err != nil && m.Question[0].Name == "example.org." {
// This is OK, we expect this backend to *not* work.
@ -62,12 +62,11 @@ func TestStubLookup(t *testing.T) {
if err != nil {
t.Errorf("expected no error, got %v for %s\n", err, m.Question[0].Name)
}
resp := rec.Msg()
resp := rec.Msg
if resp == nil {
// etcd not running?
continue
}
sort.Sort(test.RRSet(resp.Answer))
sort.Sort(test.RRSet(resp.Ns))
sort.Sort(test.RRSet(resp.Extra))

View file

@ -1,8 +0,0 @@
package middleware
import "github.com/miekg/dns"
func Exchange(c *dns.Client, m *dns.Msg, server string) (*dns.Msg, error) {
r, _, err := c.Exchange(m, server)
return r, err
}

View file

@ -5,7 +5,7 @@ import (
"strings"
"testing"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/pkg/dnsrecorder"
"github.com/miekg/coredns/middleware/test"
"github.com/miekg/dns"
@ -57,14 +57,14 @@ func TestLookupDelegation(t *testing.T) {
for _, tc := range delegationTestCases {
m := tc.Msg()
rec := middleware.NewResponseRecorder(&test.ResponseWriter{})
rec := dnsrecorder.New(&test.ResponseWriter{})
_, err := fm.ServeDNS(ctx, rec, m)
if err != nil {
t.Errorf("expected no error, got %v\n", err)
return
}
resp := rec.Msg()
resp := rec.Msg
sort.Sort(test.RRSet(resp.Answer))
sort.Sort(test.RRSet(resp.Ns))
sort.Sort(test.RRSet(resp.Extra))

View file

@ -5,7 +5,7 @@ import (
"strings"
"testing"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/pkg/dnsrecorder"
"github.com/miekg/coredns/middleware/test"
"github.com/miekg/dns"
@ -116,14 +116,14 @@ func TestLookupDNSSEC(t *testing.T) {
for _, tc := range dnssecTestCases {
m := tc.Msg()
rec := middleware.NewResponseRecorder(&test.ResponseWriter{})
rec := dnsrecorder.New(&test.ResponseWriter{})
_, err := fm.ServeDNS(ctx, rec, m)
if err != nil {
t.Errorf("expected no error, got %v\n", err)
return
}
resp := rec.Msg()
resp := rec.Msg
sort.Sort(test.RRSet(resp.Answer))
sort.Sort(test.RRSet(resp.Ns))
sort.Sort(test.RRSet(resp.Extra))
@ -154,7 +154,7 @@ func BenchmarkLookupDNSSEC(b *testing.B) {
fm := File{Next: test.ErrorHandler(), Zones: Zones{Z: map[string]*Zone{testzone: zone}, Names: []string{testzone}}}
ctx := context.TODO()
rec := middleware.NewResponseRecorder(&test.ResponseWriter{})
rec := dnsrecorder.New(&test.ResponseWriter{})
tc := test.Case{
Qname: "b.miek.nl.", Qtype: dns.TypeA, Do: true,

View file

@ -5,7 +5,7 @@ import (
"strings"
"testing"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/pkg/dnsrecorder"
"github.com/miekg/coredns/middleware/test"
"github.com/miekg/dns"
@ -43,14 +43,14 @@ func TestLookupENT(t *testing.T) {
for _, tc := range entTestCases {
m := tc.Msg()
rec := middleware.NewResponseRecorder(&test.ResponseWriter{})
rec := dnsrecorder.New(&test.ResponseWriter{})
_, err := fm.ServeDNS(ctx, rec, m)
if err != nil {
t.Errorf("expected no error, got %v\n", err)
return
}
resp := rec.Msg()
resp := rec.Msg
sort.Sort(test.RRSet(resp.Answer))
sort.Sort(test.RRSet(resp.Ns))
sort.Sort(test.RRSet(resp.Extra))

View file

@ -6,6 +6,7 @@ import (
"log"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/request"
"github.com/miekg/dns"
"golang.org/x/net/context"
@ -24,7 +25,7 @@ type (
)
func (f File) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
state := middleware.State{W: w, Req: r}
state := request.Request{W: w, Req: r}
if state.QClass() != dns.ClassINET {
return dns.RcodeServerFailure, errors.New("can only deal with ClassINET")

View file

@ -5,7 +5,7 @@ import (
"strings"
"testing"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/pkg/dnsrecorder"
"github.com/miekg/coredns/middleware/test"
"github.com/miekg/dns"
@ -87,14 +87,14 @@ func TestLookup(t *testing.T) {
for _, tc := range dnsTestCases {
m := tc.Msg()
rec := middleware.NewResponseRecorder(&test.ResponseWriter{})
rec := dnsrecorder.New(&test.ResponseWriter{})
_, err := fm.ServeDNS(ctx, rec, m)
if err != nil {
t.Errorf("expected no error, got %v\n", err)
return
}
resp := rec.Msg()
resp := rec.Msg
sort.Sort(test.RRSet(resp.Answer))
sort.Sort(test.RRSet(resp.Ns))
sort.Sort(test.RRSet(resp.Extra))
@ -122,7 +122,7 @@ func TestLookupNil(t *testing.T) {
ctx := context.TODO()
m := dnsTestCases[0].Msg()
rec := middleware.NewResponseRecorder(&test.ResponseWriter{})
rec := dnsrecorder.New(&test.ResponseWriter{})
fm.ServeDNS(ctx, rec, m)
}
@ -134,7 +134,7 @@ func BenchmarkLookup(b *testing.B) {
fm := File{Next: test.ErrorHandler(), Zones: Zones{Z: map[string]*Zone{testzone: zone}, Names: []string{testzone}}}
ctx := context.TODO()
rec := middleware.NewResponseRecorder(&test.ResponseWriter{})
rec := dnsrecorder.New(&test.ResponseWriter{})
tc := test.Case{
Qname: "www.miek.nl.", Qtype: dns.TypeA,

View file

@ -5,6 +5,7 @@ import (
"log"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/request"
"github.com/miekg/dns"
)
@ -12,7 +13,7 @@ import (
// isNotify checks if state is a notify message and if so, will *also* check if it
// is from one of the configured masters. If not it will not be a valid notify
// message. If the zone z is not a secondary zone the message will also be ignored.
func (z *Zone) isNotify(state middleware.State) bool {
func (z *Zone) isNotify(state request.Request) bool {
if state.Req.Opcode != dns.OpcodeNotify {
return false
}
@ -56,7 +57,7 @@ func notify(zone string, to []string) error {
func notifyAddr(c *dns.Client, m *dns.Msg, s string) error {
for i := 0; i < 3; i++ {
ret, err := middleware.Exchange(c, m, s)
ret, _, err := c.Exchange(m, s)
if err != nil {
continue
}

View file

@ -4,8 +4,6 @@ import (
"log"
"time"
"github.com/miekg/coredns/middleware"
"github.com/miekg/dns"
)
@ -75,7 +73,7 @@ func (z *Zone) shouldTransfer() (bool, error) {
Transfer:
for _, tr := range z.TransferFrom {
Err = nil
ret, err := middleware.Exchange(c, m, tr)
ret, _, err := c.Exchange(m, tr)
if err != nil || ret.Rcode != dns.RcodeSuccess {
Err = err
continue

View file

@ -4,8 +4,8 @@ import (
"fmt"
"testing"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/test"
"github.com/miekg/coredns/request"
"github.com/miekg/dns"
)
@ -134,7 +134,7 @@ func TestIsNotify(t *testing.T) {
z := new(Zone)
z.Expired = new(bool)
z.origin = testZone
state := NewState(testZone, dns.TypeSOA)
state := newRequest(testZone, dns.TypeSOA)
// need to set opcode
state.Req.Opcode = dns.OpcodeNotify
@ -148,9 +148,9 @@ func TestIsNotify(t *testing.T) {
}
}
func NewState(zone string, qtype uint16) middleware.State {
func newRequest(zone string, qtype uint16) request.Request {
m := new(dns.Msg)
m.SetQuestion("example.com.", dns.TypeA)
m.SetEdns0(4097, true)
return middleware.State{W: &test.ResponseWriter{}, Req: m}
return request.Request{W: &test.ResponseWriter{}, Req: m}
}

View file

@ -1,9 +1,6 @@
package tree
import (
"github.com/miekg/coredns/middleware"
"github.com/miekg/dns"
)
import "github.com/miekg/dns"
type Elem struct {
m map[uint16][]dns.RR
@ -91,8 +88,8 @@ func (e *Elem) Delete(rr dns.RR) (empty bool) {
return
}
// Less is a tree helper function that calls middleware.Less.
func Less(a *Elem, name string) int { return middleware.Less(name, a.Name()) }
// Less is a tree helper function that calls less.
func Less(a *Elem, name string) int { return less(name, a.Name()) }
// Assuming the same type and name this will check if the rdata is equal as well.
func equalRdata(a, b dns.RR) bool {

View file

@ -1,4 +1,4 @@
package middleware
package tree
import (
"bytes"
@ -6,7 +6,7 @@ import (
"github.com/miekg/dns"
)
// Less returns <0 when a is less than b, 0 when they are equal and
// less returns <0 when a is less than b, 0 when they are equal and
// >0 when a is larger than b.
// The function orders names in DNSSEC canonical order: RFC 4034s section-6.1
//
@ -14,7 +14,7 @@ import (
// for a blog article on this implementation.
//
// The values of a and b are *not* lowercased before the comparison!
func Less(a, b string) int {
func less(a, b string) int {
i := 1
aj := len(a)
bj := len(b)

View file

@ -1,4 +1,4 @@
package middleware
package tree
import (
"sort"
@ -10,7 +10,7 @@ type set []string
func (p set) Len() int { return len(p) }
func (p set) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
func (p set) Less(i, j int) bool { d := Less(p[i], p[j]); return d <= 0 }
func (p set) Less(i, j int) bool { d := less(p[i], p[j]); return d <= 0 }
func TestLess(t *testing.T) {
tests := []struct {

View file

@ -5,7 +5,7 @@ import (
"strings"
"testing"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/pkg/dnsrecorder"
"github.com/miekg/coredns/middleware/test"
"github.com/miekg/dns"
@ -57,14 +57,14 @@ func TestLookupWildcard(t *testing.T) {
for _, tc := range wildcardTestCases {
m := tc.Msg()
rec := middleware.NewResponseRecorder(&test.ResponseWriter{})
rec := dnsrecorder.New(&test.ResponseWriter{})
_, err := fm.ServeDNS(ctx, rec, m)
if err != nil {
t.Errorf("expected no error, got %v\n", err)
return
}
resp := rec.Msg()
resp := rec.Msg
sort.Sort(test.RRSet(resp.Answer))
sort.Sort(test.RRSet(resp.Ns))
sort.Sort(test.RRSet(resp.Extra))

View file

@ -4,7 +4,7 @@ import (
"fmt"
"log"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/request"
"github.com/miekg/dns"
"golang.org/x/net/context"
@ -18,7 +18,7 @@ type (
// Serve an AXFR (and fallback of IXFR) as well.
func (x Xfr) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
state := middleware.State{W: w, Req: r}
state := request.Request{W: w, Req: r}
if !x.TransferAllowed(state) {
return dns.RcodeServerFailure, nil
}

View file

@ -8,8 +8,8 @@ import (
"strings"
"sync"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/file/tree"
"github.com/miekg/coredns/request"
"github.com/fsnotify/fsnotify"
"github.com/miekg/dns"
@ -102,7 +102,7 @@ func (z *Zone) Insert(r dns.RR) error {
func (z *Zone) Delete(r dns.RR) { z.Tree.Delete(r) }
// TransferAllowed checks if incoming request for transferring the zone is allowed according to the ACLs.
func (z *Zone) TransferAllowed(state middleware.State) bool {
func (z *Zone) TransferAllowed(req request.Request) bool {
for _, t := range z.TransferTo {
if t == "*" {
return true

View file

@ -1,19 +0,0 @@
package middleware
import (
"os"
"strings"
"testing"
)
func TestfsPath(t *testing.T) {
if actual := fsPath(); !strings.HasSuffix(actual, ".coredns") {
t.Errorf("Expected path to be a .coredns folder, got: %v", actual)
}
os.Setenv("COREDNSPATH", "testpath")
defer os.Setenv("COREDNSPATH", "")
if actual, expected := fsPath(), "testpath"; actual != expected {
t.Errorf("Expected path to be %v, got: %v", expected, actual)
}
}

View file

@ -1,36 +0,0 @@
package middleware
import (
"net"
"strings"
"github.com/miekg/dns"
)
// Host represents a host from the Corefile, may contain port.
type (
Host string
Addr string
)
// Normalize will return the host portion of host, stripping
// of any port. The host will also be fully qualified and lowercased.
func (h Host) Normalize() string {
// separate host and port
host, _, err := net.SplitHostPort(string(h))
if err != nil {
host, _, _ = net.SplitHostPort(string(h) + ":")
}
return strings.ToLower(dns.Fqdn(host))
}
// Normalize will return a normalized address, if not port is specified
// port 53 is added, otherwise the port will be left as is.
func (a Addr) Normalize() string {
// separate host and port
addr, port, err := net.SplitHostPort(string(a))
if err != nil {
addr, port, _ = net.SplitHostPort(string(a) + ":53")
}
return net.JoinHostPort(addr, port)
}

View file

@ -2,16 +2,17 @@ package kubernetes
import (
"fmt"
"strings"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/pkg/dnsutil"
"github.com/miekg/coredns/request"
"github.com/miekg/dns"
"golang.org/x/net/context"
)
func (k Kubernetes) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
state := middleware.State{W: w, Req: r}
state := request.Request{W: w, Req: r}
if state.QClass() != dns.ClassINET {
return dns.RcodeServerFailure, fmt.Errorf("can only deal with ClassINET")
}
@ -21,8 +22,8 @@ func (k Kubernetes) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.M
m.Authoritative, m.RecursionAvailable, m.Compress = true, true, true
// TODO: find an alternative to this block
if strings.HasSuffix(state.Name(), arpaSuffix) {
ip, _ := extractIP(state.Name())
ip := dnsutil.ExtractAddressFromReverse(state.Name())
if ip != "" {
records := k.getServiceRecordForIP(ip, state.Name())
if len(records) > 0 {
srvPTR := &records[0]
@ -100,7 +101,7 @@ func (k Kubernetes) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.M
}
// NoData write a nodata response to the client.
func (k Kubernetes) Err(zone string, rcode int, state middleware.State) (int, error) {
func (k Kubernetes) Err(zone string, rcode int, state request.Request) (int, error) {
m := new(dns.Msg)
m.SetRcode(state.Req, rcode)
m.Ns = []dns.RR{k.SOA(zone, state)}

View file

@ -4,13 +4,13 @@ package kubernetes
import (
"errors"
"log"
"strings"
"time"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/kubernetes/msg"
"github.com/miekg/coredns/middleware/kubernetes/nametemplate"
"github.com/miekg/coredns/middleware/kubernetes/util"
"github.com/miekg/coredns/middleware/pkg/dnsutil"
"github.com/miekg/coredns/middleware/proxy"
"github.com/miekg/dns"
@ -100,8 +100,8 @@ func (k *Kubernetes) getZoneForName(name string) (string, []string) {
func (k *Kubernetes) Records(name string, exact bool) ([]msg.Service, error) {
// TODO: refector this.
// Right now GetNamespaceFromSegmentArray do not supports PRE queries
if strings.HasSuffix(name, arpaSuffix) {
ip, _ := extractIP(name)
ip := dnsutil.ExtractAddressFromReverse(name)
if ip != "" {
records := k.getServiceRecordForIP(ip, name)
return records, nil
}

View file

@ -4,21 +4,17 @@ import (
"fmt"
"math"
"net"
"strings"
"time"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/kubernetes/msg"
"github.com/miekg/coredns/middleware/pkg/dnsutil"
"github.com/miekg/coredns/request"
"github.com/miekg/dns"
)
const (
// arpaSuffix is the standard suffix for PTR IP reverse lookups.
arpaSuffix = ".in-addr.arpa."
)
func (k Kubernetes) records(state middleware.State, exact bool) ([]msg.Service, error) {
func (k Kubernetes) records(state request.Request, exact bool) ([]msg.Service, error) {
services, err := k.Records(state.Name(), exact)
if err != nil {
return nil, err
@ -28,7 +24,7 @@ func (k Kubernetes) records(state middleware.State, exact bool) ([]msg.Service,
return services, nil
}
func (k Kubernetes) A(zone string, state middleware.State, previousRecords []dns.RR) (records []dns.RR, err error) {
func (k Kubernetes) A(zone string, state request.Request, previousRecords []dns.RR) (records []dns.RR, err error) {
services, err := k.records(state, false)
if err != nil {
return nil, err
@ -49,11 +45,11 @@ func (k Kubernetes) A(zone string, state middleware.State, previousRecords []dns
// don't add it, and just continue
continue
}
if isDuplicateCNAME(newRecord, previousRecords) {
if dnsutil.DuplicateCNAME(newRecord, previousRecords) {
continue
}
state1 := copyState(state, serv.Host, state.QType())
state1 := state.NewWithQuestion(serv.Host, state.QType())
nextRecords, err := k.A(zone, state1, append(previousRecords, newRecord))
if err == nil {
@ -87,7 +83,7 @@ func (k Kubernetes) A(zone string, state middleware.State, previousRecords []dns
return records, nil
}
func (k Kubernetes) AAAA(zone string, state middleware.State, previousRecords []dns.RR) (records []dns.RR, err error) {
func (k Kubernetes) AAAA(zone string, state request.Request, previousRecords []dns.RR) (records []dns.RR, err error) {
services, err := k.records(state, false)
if err != nil {
return nil, err
@ -108,11 +104,11 @@ func (k Kubernetes) AAAA(zone string, state middleware.State, previousRecords []
// don't add it, and just continue
continue
}
if isDuplicateCNAME(newRecord, previousRecords) {
if dnsutil.DuplicateCNAME(newRecord, previousRecords) {
continue
}
state1 := copyState(state, serv.Host, state.QType())
state1 := state.NewWithQuestion(serv.Host, state.QType())
nextRecords, err := k.AAAA(zone, state1, append(previousRecords, newRecord))
if err == nil {
@ -149,7 +145,7 @@ func (k Kubernetes) AAAA(zone string, state middleware.State, previousRecords []
// SRV returns SRV records from kubernetes.
// If the Target is not a name but an IP address, a name is created on the fly.
func (k Kubernetes) SRV(zone string, state middleware.State) (records []dns.RR, extra []dns.RR, err error) {
func (k Kubernetes) SRV(zone string, state request.Request) (records []dns.RR, extra []dns.RR, err error) {
services, err := k.records(state, false)
if err != nil {
return nil, nil, err
@ -207,7 +203,7 @@ func (k Kubernetes) SRV(zone string, state middleware.State) (records []dns.RR,
}
// Internal name, we should have some info on them, either v4 or v6
// Clients expect a complete answer, because we are a recursor in their view.
state1 := copyState(state, srv.Target, dns.TypeA)
state1 := state.NewWithQuestion(srv.Target, dns.TypeA)
addr, e1 := k.A(zone, state1, nil)
if e1 == nil {
extra = append(extra, addr...)
@ -231,21 +227,21 @@ func (k Kubernetes) SRV(zone string, state middleware.State) (records []dns.RR,
}
// Returning MX records from kubernetes not implemented.
func (k Kubernetes) MX(zone string, state middleware.State) (records []dns.RR, extra []dns.RR, err error) {
func (k Kubernetes) MX(zone string, state request.Request) (records []dns.RR, extra []dns.RR, err error) {
return nil, nil, err
}
// Returning CNAME records from kubernetes not implemented.
func (k Kubernetes) CNAME(zone string, state middleware.State) (records []dns.RR, err error) {
func (k Kubernetes) CNAME(zone string, state request.Request) (records []dns.RR, err error) {
return nil, err
}
// Returning TXT records from kubernetes not implemented.
func (k Kubernetes) TXT(zone string, state middleware.State) (records []dns.RR, err error) {
func (k Kubernetes) TXT(zone string, state request.Request) (records []dns.RR, err error) {
return nil, err
}
func (k Kubernetes) NS(zone string, state middleware.State) (records, extra []dns.RR, err error) {
func (k Kubernetes) NS(zone string, state request.Request) (records, extra []dns.RR, err error) {
// NS record for this zone live in a special place, ns.dns.<zone>. Fake our lookup.
// only a tad bit fishy...
old := state.QName()
@ -278,7 +274,7 @@ func (k Kubernetes) NS(zone string, state middleware.State) (records, extra []dn
}
// SOA Record returns a SOA record.
func (k Kubernetes) SOA(zone string, state middleware.State) *dns.SOA {
func (k Kubernetes) SOA(zone string, state request.Request) *dns.SOA {
header := dns.RR_Header{Name: zone, Rrtype: dns.TypeSOA, Ttl: 300, Class: dns.ClassINET}
return &dns.SOA{Hdr: header,
Mbox: "hostmaster." + zone,
@ -291,9 +287,9 @@ func (k Kubernetes) SOA(zone string, state middleware.State) *dns.SOA {
}
}
func (k Kubernetes) PTR(zone string, state middleware.State) ([]dns.RR, error) {
reverseIP, ok := extractIP(state.Name())
if !ok {
func (k Kubernetes) PTR(zone string, state request.Request) ([]dns.RR, error) {
reverseIP := dnsutil.ExtractAddressFromReverse(state.Name())
if reverseIP == "" {
return nil, fmt.Errorf("does not support reverse lookup for %s", state.QName())
}
@ -318,41 +314,3 @@ func (k Kubernetes) PTR(zone string, state middleware.State) ([]dns.RR, error) {
}
return records, nil
}
func isDuplicateCNAME(r *dns.CNAME, records []dns.RR) bool {
for _, rec := range records {
if v, ok := rec.(*dns.CNAME); ok {
if v.Target == r.Target {
return true
}
}
}
return false
}
func copyState(state middleware.State, target string, typ uint16) middleware.State {
state1 := middleware.State{W: state.W, Req: state.Req.Copy()}
state1.Req.Question[0] = dns.Question{Name: dns.Fqdn(target), Qtype: dns.ClassINET, Qclass: typ}
return state1
}
// extractIP turns a standard PTR reverse record lookup name
// into an IP address
func extractIP(reverseName string) (string, bool) {
if !strings.HasSuffix(reverseName, arpaSuffix) {
return "", false
}
search := strings.TrimSuffix(reverseName, arpaSuffix)
// reverse the segments and then combine them
segments := reverseArray(strings.Split(search, "."))
return strings.Join(segments, "."), true
}
func reverseArray(arr []string) []string {
for i := 0; i < len(arr)/2; i++ {
j := len(arr) - i - 1
arr[i], arr[j] = arr[j], arr[i]
}
return arr
}

View file

@ -68,7 +68,7 @@ func kubernetesParse(c *caddy.Controller) (Kubernetes, error) {
}
k8s.Zones = NormalizeZoneList(zones)
middleware.Zones(k8s.Zones).FullyQualify()
middleware.Zones(k8s.Zones).Normalize()
if k8s.Zones == nil || len(k8s.Zones) < 1 {
err = errors.New("Zone name must be provided for kubernetes middleware.")

View file

@ -4,6 +4,7 @@ import (
"testing"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/pkg/dnsrecorder"
"github.com/miekg/coredns/middleware/test"
"github.com/miekg/dns"
@ -56,7 +57,7 @@ func TestLoadBalance(t *testing.T) {
},
}
rec := middleware.NewResponseRecorder(&test.ResponseWriter{})
rec := dnsrecorder.New(&test.ResponseWriter{})
for i, test := range tests {
req := new(dns.Msg)
@ -71,7 +72,7 @@ func TestLoadBalance(t *testing.T) {
}
cname := 0
for _, r := range rec.Msg().Answer {
for _, r := range rec.Msg.Answer {
if r.Header().Rrtype != dns.TypeCNAME {
break
}
@ -81,7 +82,7 @@ func TestLoadBalance(t *testing.T) {
t.Errorf("Test %d: Expected %d cnames in Answer, but got %d", i, test.cnameAnswer, cname)
}
cname = 0
for _, r := range rec.Msg().Extra {
for _, r := range rec.Msg.Extra {
if r.Header().Rrtype != dns.TypeCNAME {
break
}

View file

@ -7,6 +7,11 @@ import (
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/metrics"
"github.com/miekg/coredns/middleware/pkg/dnsrecorder"
"github.com/miekg/coredns/middleware/pkg/rcode"
"github.com/miekg/coredns/middleware/pkg/replacer"
"github.com/miekg/coredns/middleware/pkg/roller"
"github.com/miekg/coredns/request"
"github.com/miekg/dns"
"golang.org/x/net/context"
@ -20,32 +25,30 @@ type Logger struct {
}
func (l Logger) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
state := middleware.State{W: w, Req: r}
state := request.Request{W: w, Req: r}
for _, rule := range l.Rules {
if middleware.Name(rule.NameScope).Matches(state.Name()) {
responseRecorder := middleware.NewResponseRecorder(w)
rcode, err := l.Next.ServeDNS(ctx, responseRecorder, r)
responseRecorder := dnsrecorder.New(w)
rc, err := l.Next.ServeDNS(ctx, responseRecorder, r)
if rcode > 0 {
if rc > 0 {
// There was an error up the chain, but no response has been written yet.
// The error must be handled here so the log entry will record the response size.
if l.ErrorFunc != nil {
l.ErrorFunc(responseRecorder, r, rcode)
l.ErrorFunc(responseRecorder, r, rc)
} else {
rc := middleware.RcodeToString(rcode)
answer := new(dns.Msg)
answer.SetRcode(r, rcode)
answer.SetRcode(r, rc)
state.SizeAndDo(answer)
metrics.Report(state, metrics.Dropped, rc, answer.Len(), time.Now())
metrics.Report(state, metrics.Dropped, rcode.ToString(rc), answer.Len(), time.Now())
w.WriteMsg(answer)
}
rcode = 0
rc = 0
}
rep := middleware.NewReplacer(r, responseRecorder, CommonLogEmptyValue)
rep := replacer.New(r, responseRecorder, CommonLogEmptyValue)
rule.Log.Println(rep.Replace(rule.Format))
return rcode, err
return rc, err
}
}
@ -58,7 +61,7 @@ type Rule struct {
OutputFile string
Format string
Log *log.Logger
Roller *middleware.LogRoller
Roller *roller.LogRoller
}
const (

View file

@ -6,7 +6,7 @@ import (
"strings"
"testing"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/pkg/dnsrecorder"
"github.com/miekg/coredns/middleware/test"
"github.com/miekg/dns"
@ -37,7 +37,7 @@ func TestLoggedStatus(t *testing.T) {
r := new(dns.Msg)
r.SetQuestion("example.org.", dns.TypeA)
rec := middleware.NewResponseRecorder(&test.ResponseWriter{})
rec := dnsrecorder.New(&test.ResponseWriter{})
rcode, _ := logger.ServeDNS(ctx, rec, r)
if rcode != 0 {

View file

@ -6,7 +6,7 @@ import (
"os"
"github.com/miekg/coredns/core/dnsserver"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/pkg/roller"
"github.com/hashicorp/go-syslog"
"github.com/mholt/caddy"
@ -75,13 +75,13 @@ func logParse(c *caddy.Controller) ([]Rule, error) {
for c.Next() {
args := c.RemainingArgs()
var logRoller *middleware.LogRoller
var logRoller *roller.LogRoller
if c.NextBlock() {
if c.Val() == "rotate" {
if c.NextArg() {
if c.Val() == "{" {
var err error
logRoller, err = middleware.ParseRoller(c)
logRoller, err = roller.Parse(c)
if err != nil {
return nil, err
}

View file

@ -3,7 +3,7 @@ package log
import (
"testing"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/pkg/roller"
"github.com/mholt/caddy"
)
@ -68,7 +68,7 @@ func TestLogParse(t *testing.T) {
NameScope: ".",
OutputFile: "access.log",
Format: DefaultLogFormat,
Roller: &middleware.LogRoller{
Roller: &roller.LogRoller{
MaxSize: 2,
MaxAge: 10,
MaxBackups: 3,

View file

@ -4,13 +4,16 @@ import (
"time"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/pkg/dnsrecorder"
"github.com/miekg/coredns/middleware/pkg/rcode"
"github.com/miekg/coredns/request"
"github.com/miekg/dns"
"golang.org/x/net/context"
)
func (m Metrics) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
state := middleware.State{W: w, Req: r}
state := request.Request{W: w, Req: r}
qname := state.QName()
zone := middleware.Zones(m.ZoneNames).Matches(qname)
@ -19,35 +22,35 @@ func (m Metrics) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg)
}
// Record response to get status code and size of the reply.
rw := middleware.NewResponseRecorder(w)
rw := dnsrecorder.New(w)
status, err := m.Next.ServeDNS(ctx, rw, r)
Report(state, zone, rw.Rcode(), rw.Size(), rw.Start())
Report(state, zone, rcode.ToString(rw.Rcode), rw.Size, rw.Start)
return status, err
}
// Report is a plain reporting function that the server can use for REFUSED and other
// queries that are turned down because they don't match any middleware.
func Report(state middleware.State, zone, rcode string, size int, start time.Time) {
func Report(req request.Request, zone, rcode string, size int, start time.Time) {
if requestCount == nil {
// no metrics are enabled
return
}
// Proto and Family
net := state.Proto()
net := req.Proto()
fam := "1"
if state.Family() == 2 {
if req.Family() == 2 {
fam = "2"
}
typ := state.QType()
typ := req.QType()
requestCount.WithLabelValues(zone, net, fam).Inc()
requestDuration.WithLabelValues(zone).Observe(float64(time.Since(start) / time.Millisecond))
if state.Do() {
if req.Do() {
requestDo.WithLabelValues(zone).Inc()
}
@ -59,10 +62,10 @@ func Report(state middleware.State, zone, rcode string, size int, start time.Tim
if typ == dns.TypeIXFR || typ == dns.TypeAXFR {
responseTransferSize.WithLabelValues(zone, net).Observe(float64(size))
requestTransferSize.WithLabelValues(zone, net).Observe(float64(state.Size()))
requestTransferSize.WithLabelValues(zone, net).Observe(float64(req.Size()))
} else {
responseSize.WithLabelValues(zone, net).Observe(float64(size))
requestSize.WithLabelValues(zone, net).Observe(float64(state.Size()))
requestSize.WithLabelValues(zone, net).Observe(float64(req.Size()))
}
responseRcode.WithLabelValues(zone, rcode).Inc()

View file

@ -1,25 +0,0 @@
package middleware
import (
"strings"
"github.com/miekg/dns"
)
// Name represents a domain name.
type Name string
// Matches checks to see if other is a subdomain (or the same domain) of n.
// This method assures that names can be easily and consistently matched.
func (n Name) Matches(child string) bool {
if dns.Name(n) == dns.Name(child) {
return true
}
return dns.IsSubDomain(string(n), child)
}
// Normalize lowercases and makes n fully qualified.
func (n Name) Normalize() string {
return strings.ToLower(dns.Fqdn(string(n)))
}

78
middleware/normalize.go Normal file
View file

@ -0,0 +1,78 @@
package middleware
import (
"net"
"strings"
"github.com/miekg/dns"
)
type Zones []string
// Matches checks is qname is a subdomain of any of the zones in z. The match
// will return the most specific zones that matches other. The empty string
// signals a not found condition.
func (z Zones) Matches(qname string) string {
zone := ""
for _, zname := range z {
if dns.IsSubDomain(zname, qname) {
// TODO(miek): hmm, add test for this case
if len(zname) > len(zone) {
zone = zname
}
}
}
return zone
}
// Normalize fully qualifies all zones in z.
func (z Zones) Normalize() {
for i, _ := range z {
z[i] = Name(z[i]).Normalize()
}
}
// Name represents a domain name.
type Name string
// Matches checks to see if other is a subdomain (or the same domain) of n.
// This method assures that names can be easily and consistently matched.
func (n Name) Matches(child string) bool {
if dns.Name(n) == dns.Name(child) {
return true
}
return dns.IsSubDomain(string(n), child)
}
// Normalize lowercases and makes n fully qualified.
func (n Name) Normalize() string { return strings.ToLower(dns.Fqdn(string(n))) }
// Host represents a host from the Corefile, may contain port.
type (
Host string
Addr string
)
// Normalize will return the host portion of host, stripping
// of any port. The host will also be fully qualified and lowercased.
func (h Host) Normalize() string {
// separate host and port
host, _, err := net.SplitHostPort(string(h))
if err != nil {
host, _, _ = net.SplitHostPort(string(h) + ":")
}
return Name(host).Normalize()
}
// Normalize will return a normalized address, if not port is specified
// port 53 is added, otherwise the port will be left as is.
func (a Addr) Normalize() string {
// separate host and port
addr, port, err := net.SplitHostPort(string(a))
if err != nil {
addr, port, _ = net.SplitHostPort(string(a) + ":53")
}
// TODO(miek): lowercase it?
return net.JoinHostPort(addr, port)
}

View file

@ -0,0 +1,57 @@
package dnsrecorder
import (
"time"
"github.com/miekg/dns"
)
// Recorder is a type of ResponseWriter that captures
// the rcode code written to it and also the size of the message
// written in the response. A rcode code does not have
// to be written, however, in which case 0 must be assumed.
// It is best to have the constructor initialize this type
// with that default status code.
type Recorder struct {
dns.ResponseWriter
Rcode int
Size int
Msg *dns.Msg
Start time.Time
}
// New makes and returns a new Recorder,
// which captures the DNS rcode from the ResponseWriter
// and also the length of the response message written through it.
func New(w dns.ResponseWriter) *Recorder {
return &Recorder{
ResponseWriter: w,
Rcode: 0,
Msg: nil,
Start: time.Now(),
}
}
// WriteMsg records the status code and calls the
// underlying ResponseWriter's WriteMsg method.
func (r *Recorder) WriteMsg(res *dns.Msg) error {
r.Rcode = res.Rcode
// We may get called multiple times (axfr for instance).
// Save the last message, but add the sizes.
r.Size += res.Len()
r.Msg = res
return r.ResponseWriter.WriteMsg(res)
}
// Write is a wrapper that records the size of the message that gets written.
func (r *Recorder) Write(buf []byte) (int, error) {
n, err := r.ResponseWriter.Write(buf)
if err == nil {
r.Size += n
}
return n, err
}
// Hijack implements dns.Hijacker. It simply wraps the underlying
// ResponseWriter's Hijack method if there is one, or returns an error.
func (r *Recorder) Hijack() { r.ResponseWriter.Hijack(); return }

View file

@ -1,4 +1,4 @@
package middleware
package dnsrecorder
/*
func TestNewResponseRecorder(t *testing.T) {

View file

@ -0,0 +1,15 @@
package dnsutil
import "github.com/miekg/dns"
// DuplicateCNAME returns true if r already exists in records.
func DuplicateCNAME(r *dns.CNAME, records []dns.RR) bool {
for _, rec := range records {
if v, ok := rec.(*dns.CNAME); ok {
if v.Target == r.Target {
return true
}
}
}
return false
}

View file

@ -0,0 +1,40 @@
package dnsutil
import "strings"
// ExtractAddressFromReverse turns a standard PTR reverse record name
// into an IP address. This works for ipv4 or ipv6.
//
// 54.119.58.176.in-addr.arpa. becomes 176.58.119.54. If the conversion
// failes the empty string is returned.
func ExtractAddressFromReverse(reverseName string) string {
search := ""
switch {
case strings.HasSuffix(reverseName, v4arpaSuffix):
search = strings.TrimSuffix(reverseName, v4arpaSuffix)
case strings.HasSuffix(reverseName, v6arpaSuffix):
search = strings.TrimSuffix(reverseName, v6arpaSuffix)
default:
return ""
}
// Reverse the segments and then combine them.
segments := reverse(strings.Split(search, "."))
return strings.Join(segments, ".")
}
func reverse(slice []string) []string {
for i := 0; i < len(slice)/2; i++ {
j := len(slice) - i - 1
slice[i], slice[j] = slice[j], slice[i]
}
return slice
}
const (
// v4arpaSuffix is the reverse tree suffix for v4 IP addresses.
v4arpaSuffix = ".in-addr.arpa."
// v6arpaSuffix is the reverse tree suffix for v6 IP addresses.
v6arpaSuffix = ".ip6.arpa."
)

View file

@ -1,4 +1,4 @@
package middleware
package edns
import (
"errors"
@ -6,11 +6,11 @@ import (
"github.com/miekg/dns"
)
// Edns0Version checks the EDNS version in the request. If error
// Version checks the EDNS version in the request. If error
// is nil everything is OK and we can invoke the middleware. If non-nil, the
// returned Msg is valid to be returned to the client (and should). For some
// reason this response should not contain a question RR in the question section.
func Edns0Version(req *dns.Msg) (*dns.Msg, error) {
func Version(req *dns.Msg) (*dns.Msg, error) {
opt := req.IsEdns0()
if opt == nil {
return nil, nil
@ -33,8 +33,8 @@ func Edns0Version(req *dns.Msg) (*dns.Msg, error) {
return m, errors.New("EDNS0 BADVERS")
}
// edns0Size returns a normalized size based on proto.
func edns0Size(proto string, size int) int {
// Size returns a normalized size based on proto.
func Size(proto string, size int) int {
if proto == "tcp" {
return dns.MaxMsgSize
}

View file

@ -1,4 +1,4 @@
package middleware
package edns
import (
"testing"
@ -6,21 +6,21 @@ import (
"github.com/miekg/dns"
)
func TestEdns0Version(t *testing.T) {
func TestVersion(t *testing.T) {
m := ednsMsg()
m.Extra[0].(*dns.OPT).SetVersion(2)
_, err := Edns0Version(m)
_, err := Version(m)
if err == nil {
t.Errorf("expected wrong version, but got OK")
}
}
func TestEdns0VersionNoEdns(t *testing.T) {
func TestVersionNoEdns(t *testing.T) {
m := ednsMsg()
m.Extra = nil
_, err := Edns0Version(m)
_, err := Version(m)
if err != nil {
t.Errorf("expected no error, but got one: %s", err)
}

View file

@ -1,4 +1,4 @@
package middleware
package rcode
import (
"strconv"
@ -6,7 +6,7 @@ import (
"github.com/miekg/dns"
)
func RcodeToString(rcode int) string {
func ToString(rcode int) string {
if str, ok := dns.RcodeToString[rcode]; ok {
return str
}

View file

@ -1,10 +1,13 @@
package middleware
package replacer
import (
"strconv"
"strings"
"time"
"github.com/miekg/coredns/middleware/pkg/dnsrecorder"
"github.com/miekg/coredns/request"
"github.com/miekg/dns"
)
@ -22,46 +25,43 @@ type replacer struct {
emptyValue string
}
// NewReplacer makes a new replacer based on r and rr.
// New makes a new replacer based on r and rr.
// Do not create a new replacer until r and rr have all
// the needed values, because this function copies those
// values into the replacer. rr may be nil if it is not
// available. emptyValue should be the string that is used
// in place of empty string (can still be empty string).
func NewReplacer(r *dns.Msg, rr *ResponseRecorder, emptyValue string) Replacer {
state := State{W: rr, Req: r}
func New(r *dns.Msg, rr *dnsrecorder.Recorder, emptyValue string) Replacer {
req := request.Request{W: rr, Req: r}
rep := replacer{
replacements: map[string]string{
"{type}": state.Type(),
"{name}": state.Name(),
"{class}": state.Class(),
"{proto}": state.Proto(),
"{type}": req.Type(),
"{name}": req.Name(),
"{class}": req.Class(),
"{proto}": req.Proto(),
"{when}": func() string {
return time.Now().Format(timeFormat)
}(),
"{remote}": state.IP(),
"{port}": func() string {
p, _ := state.Port()
return p
}(),
"{remote}": req.IP(),
"{port}": req.Port(),
},
emptyValue: emptyValue,
}
if rr != nil {
rcode := dns.RcodeToString[rr.rcode]
rcode := dns.RcodeToString[rr.Rcode]
if rcode == "" {
rcode = strconv.Itoa(rr.rcode)
rcode = strconv.Itoa(rr.Rcode)
}
rep.replacements["{rcode}"] = rcode
rep.replacements["{size}"] = strconv.Itoa(rr.size)
rep.replacements["{duration}"] = time.Since(rr.start).String()
rep.replacements["{size}"] = strconv.Itoa(rr.Size)
rep.replacements["{duration}"] = time.Since(rr.Start).String()
}
// Header placeholders (case-insensitive)
rep.replacements[headerReplacer+"id}"] = strconv.Itoa(int(r.Id))
rep.replacements[headerReplacer+"opcode}"] = strconv.Itoa(int(r.Opcode))
rep.replacements[headerReplacer+"do}"] = boolToString(state.Do())
rep.replacements[headerReplacer+"bufsize}"] = strconv.Itoa(state.Size())
rep.replacements[headerReplacer+"do}"] = boolToString(req.Do())
rep.replacements[headerReplacer+"bufsize}"] = strconv.Itoa(req.Size())
return rep
}

View file

@ -1,4 +1,4 @@
package middleware
package replacer
/*
func TestNewReplacer(t *testing.T) {

View file

@ -1,19 +1,19 @@
package middleware
package response
import "github.com/miekg/dns"
type MsgType int
type Type int
const (
Success MsgType = iota
NameError // NXDOMAIN in header, SOA in auth.
NoData // NOERROR in header, SOA in auth.
Delegation // NOERROR in header, NS in auth, optionally fluff in additional (not checked).
OtherError // Don't cache these.
Success Type = iota
NameError // NXDOMAIN in header, SOA in auth.
NoData // NOERROR in header, SOA in auth.
Delegation // NOERROR in header, NS in auth, optionally fluff in additional (not checked).
OtherError // Don't cache these.
)
// Classify classifies a message, it returns the MessageType.
func Classify(m *dns.Msg) (MsgType, *dns.OPT) {
// Classify classifies a message, it returns the Type.
func Classify(m *dns.Msg) (Type, *dns.OPT) {
opt := m.IsEdns0()
if len(m.Answer) > 0 && m.Rcode == dns.RcodeSuccess {

View file

@ -1,4 +1,4 @@
package middleware
package response
import (
"testing"

View file

@ -1,4 +1,4 @@
package middleware
package roller
import (
"io"
@ -8,7 +8,7 @@ import (
"gopkg.in/natefinch/lumberjack.v2"
)
func ParseRoller(c *caddy.Controller) (*LogRoller, error) {
func Parse(c *caddy.Controller) (*LogRoller, error) {
var size, age, keep int
// This is kind of a hack to support nested blocks:
// As we are already in a block: either log or errors,

View file

@ -1,25 +1,39 @@
package middleware
package storage
import (
"net/http"
"os"
"path"
"path/filepath"
"runtime"
)
// dir wraps http.Dir that restrict file access to a specific directory tree.
// dir wraps an http.Dir that restrict file access to a specific directory tree, see http.Dir's documentation
// for methods for accessing files.
type dir http.Dir
// CoreDir is the directory where middleware can store assets, like zone files after a zone transfer
// or public and private keys or anything else a middleware might need. The convention is to place
// assets in a subdirectory named after the fully qualified zone.
// assets in a subdirectory named after the zone prefixed with "D", to prevent the root zone become a hidden directory.
//
// example.org./Kexample<something>.key
// Dexample.org/Kexample.org<something>.key
//
// Note that subzone(s) under example.org are places in the own directory under CoreDir:
//
// Dexample.org/...
// Db.example.org/...
//
// CoreDir will default to "$HOME/.coredns" on Unix, but it's location can be overriden with the COREDNSPATH
// environment variable.
var CoreDir dir = dir(fsPath())
func (d dir) Zone(z string) dir {
if z != "." && z[len(z)-2] == '.' {
return dir(path.Join(string(d), "D"+z[:len(z)-1]))
}
return dir(path.Join(string(d), "D"+z))
}
// fsPath returns the path to the directory where the application may store data.
// If COREDNSPATH env variable. is set, that value is used. Otherwise, the path is
// the result of evaluating "$HOME/.coredns".

View file

@ -0,0 +1,42 @@
package storage
import (
"os"
"path"
"strings"
"testing"
)
func TestfsPath(t *testing.T) {
if actual := fsPath(); !strings.HasSuffix(actual, ".coredns") {
t.Errorf("Expected path to be a .coredns folder, got: %v", actual)
}
os.Setenv("COREDNSPATH", "testpath")
defer os.Setenv("COREDNSPATH", "")
if actual, expected := fsPath(), "testpath"; actual != expected {
t.Errorf("Expected path to be %v, got: %v", expected, actual)
}
}
func TestZone(t *testing.T) {
for _, ts := range []string{"example.org.", "example.org"} {
d := CoreDir.Zone(ts)
actual := path.Base(string(d))
expected := "D" + ts
if actual != expected {
t.Errorf("Expected path to be %v, got %v", actual, expected)
}
}
}
func TestZoneRoot(t *testing.T) {
for _, ts := range []string{"."} {
d := CoreDir.Zone(ts)
actual := path.Base(string(d))
expected := "D" + ts
if actual != expected {
t.Errorf("Expected path to be %v, got %v", actual, expected)
}
}
}

View file

@ -7,7 +7,8 @@ import (
"sync/atomic"
"time"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/request"
"github.com/miekg/dns"
)
@ -54,18 +55,19 @@ func New(hosts []string) Proxy {
// Lookup will use name and type to forge a new message and will send that upstream. It will
// set any EDNS0 options correctly so that downstream will be able to process the reply.
// Lookup is not suitable for forwarding request. Ssee for that.
func (p Proxy) Lookup(state middleware.State, name string, tpe uint16) (*dns.Msg, error) {
func (p Proxy) Lookup(state request.Request, name string, tpe uint16) (*dns.Msg, error) {
req := new(dns.Msg)
req.SetQuestion(name, tpe)
state.SizeAndDo(req)
return p.lookup(state, req)
}
func (p Proxy) Forward(state middleware.State) (*dns.Msg, error) {
func (p Proxy) Forward(state request.Request) (*dns.Msg, error) {
return p.lookup(state, state.Req)
}
func (p Proxy) lookup(state middleware.State, r *dns.Msg) (*dns.Msg, error) {
func (p Proxy) lookup(state request.Request, r *dns.Msg) (*dns.Msg, error) {
var (
reply *dns.Msg
err error
@ -84,9 +86,9 @@ func (p Proxy) lookup(state middleware.State, r *dns.Msg) (*dns.Msg, error) {
atomic.AddInt64(&host.Conns, 1)
if state.Proto() == "tcp" {
reply, err = middleware.Exchange(p.Client.TCP, r, host.Name)
reply, _, err = p.Client.TCP.Exchange(r, host.Name)
} else {
reply, err = middleware.Exchange(p.Client.UDP, r, host.Name)
reply, _, err = p.Client.UDP.Exchange(r, host.Name)
}
atomic.AddInt64(&host.Conns, -1)

View file

@ -2,7 +2,7 @@
package proxy
import (
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/request"
"github.com/miekg/dns"
)
@ -20,10 +20,10 @@ func (p ReverseProxy) ServeDNS(w dns.ResponseWriter, r *dns.Msg, extra []dns.RR)
)
switch {
case middleware.Proto(w) == "tcp":
reply, err = middleware.Exchange(p.Client.TCP, r, p.Host)
case request.Proto(w) == "tcp": // TODO(miek): keep this in request
reply, _, err = p.Client.TCP.Exchange(r, p.Host)
default:
reply, err = middleware.Exchange(p.Client.UDP, r, p.Host)
reply, _, err = p.Client.UDP.Exchange(r, p.Host)
}
if reply != nil && reply.Truncated {

View file

@ -1,72 +0,0 @@
package middleware
import (
"time"
"github.com/miekg/dns"
)
// ResponseRecorder is a type of ResponseWriter that captures
// the rcode code written to it and also the size of the message
// written in the response. A rcode code does not have
// to be written, however, in which case 0 must be assumed.
// It is best to have the constructor initialize this type
// with that default status code.
type ResponseRecorder struct {
dns.ResponseWriter
rcode int
size int
msg *dns.Msg
start time.Time
}
// NewResponseRecorder makes and returns a new responseRecorder,
// which captures the DNS rcode from the ResponseWriter
// and also the length of the response message written through it.
func NewResponseRecorder(w dns.ResponseWriter) *ResponseRecorder {
return &ResponseRecorder{
ResponseWriter: w,
rcode: 0,
msg: nil,
start: time.Now(),
}
}
// WriteMsg records the status code and calls the
// underlying ResponseWriter's WriteMsg method.
func (r *ResponseRecorder) WriteMsg(res *dns.Msg) error {
r.rcode = res.Rcode
// We may get called multiple times (axfr for instance).
// Save the last message, but add the sizes.
r.size += res.Len()
r.msg = res
return r.ResponseWriter.WriteMsg(res)
}
// Write is a wrapper that records the size of the message that gets written.
func (r *ResponseRecorder) Write(buf []byte) (int, error) {
n, err := r.ResponseWriter.Write(buf)
if err == nil {
r.size += n
}
return n, err
}
// Size returns the size.
func (r *ResponseRecorder) Size() int { return r.size }
// Rcode returns the rcode.
func (r *ResponseRecorder) Rcode() string { return RcodeToString(r.rcode) }
// Start returns the start time of the ResponseRecorder.
func (r *ResponseRecorder) Start() time.Time { return r.start }
// Msg returns the written message from the ResponseRecorder.
func (r *ResponseRecorder) Msg() *dns.Msg { return r.msg }
// Hijack implements dns.Hijacker. It simply wraps the underlying
// ResponseWriter's Hijack method if there is one, or returns an error.
func (r *ResponseRecorder) Hijack() {
r.ResponseWriter.Hijack()
return
}

View file

@ -5,7 +5,8 @@ import (
"regexp"
"strings"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/pkg/replacer"
"github.com/miekg/dns"
)
@ -25,8 +26,8 @@ func operatorError(operator string) error {
return fmt.Errorf("Invalid operator %v", operator)
}
func newReplacer(r *dns.Msg) middleware.Replacer {
return middleware.NewReplacer(r, nil, "")
func newReplacer(r *dns.Msg) replacer.Replacer {
return replacer.New(r, nil, "")
}
// condition is a rewrite condition.

View file

@ -117,143 +117,3 @@ func (s SimpleRule) Rewrite(r *dns.Msg) Result {
}
return RewriteIgnored
}
/*
// ComplexRule is a rewrite rule based on a regular expression
type ComplexRule struct {
// Path base. Request to this path and subpaths will be rewritten
Base string
// Path to rewrite to
To string
// If set, neither performs rewrite nor proceeds
// with request. Only returns code.
Status int
// Extensions to filter by
Exts []string
// Rewrite conditions
Ifs []If
*regexp.Regexp
}
// NewComplexRule creates a new RegexpRule. It returns an error if regexp
// pattern (pattern) or extensions (ext) are invalid.
func NewComplexRule(base, pattern, to string, status int, ext []string, ifs []If) (*ComplexRule, error) {
// validate regexp if present
var r *regexp.Regexp
if pattern != "" {
var err error
r, err = regexp.Compile(pattern)
if err != nil {
return nil, err
}
}
// validate extensions if present
for _, v := range ext {
if len(v) < 2 || (len(v) < 3 && v[0] == '!') {
// check if no extension is specified
if v != "/" && v != "!/" {
return nil, fmt.Errorf("invalid extension %v", v)
}
}
}
return &ComplexRule{
Base: base,
To: to,
Status: status,
Exts: ext,
Ifs: ifs,
Regexp: r,
}, nil
}
// Rewrite rewrites the internal location of the current request.
func (r *ComplexRule) Rewrite(req *dns.Msg) (re Result) {
rPath := req.URL.Path
replacer := newReplacer(req)
// validate base
if !middleware.Path(rPath).Matches(r.Base) {
return
}
// validate extensions
if !r.matchExt(rPath) {
return
}
// validate regexp if present
if r.Regexp != nil {
// include trailing slash in regexp if present
start := len(r.Base)
if strings.HasSuffix(r.Base, "/") {
start--
}
matches := r.FindStringSubmatch(rPath[start:])
switch len(matches) {
case 0:
// no match
return
default:
// set regexp match variables {1}, {2} ...
for i := 1; i < len(matches); i++ {
replacer.Set(fmt.Sprint(i), matches[i])
}
}
}
// validate rewrite conditions
for _, i := range r.Ifs {
if !i.True(req) {
return
}
}
// if status is present, stop rewrite and return it.
if r.Status != 0 {
return RewriteStatus
}
// attempt rewrite
return To(fs, req, r.To, replacer)
}
// matchExt matches rPath against registered file extensions.
// Returns true if a match is found and false otherwise.
func (r *ComplexRule) matchExt(rPath string) bool {
f := filepath.Base(rPath)
ext := path.Ext(f)
if ext == "" {
ext = "/"
}
mustUse := false
for _, v := range r.Exts {
use := true
if v[0] == '!' {
use = false
v = v[1:]
}
if use {
mustUse = true
}
if ext == v {
return use
}
}
if mustUse {
return false
}
return true
}
*/

View file

@ -4,6 +4,7 @@ import (
"testing"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/pkg/dnsrecorder"
"github.com/miekg/coredns/middleware/test"
"github.com/miekg/dns"
@ -50,10 +51,10 @@ func TestRewrite(t *testing.T) {
m.SetQuestion(tc.from, tc.fromT)
m.Question[0].Qclass = tc.fromC
rec := middleware.NewResponseRecorder(&test.ResponseWriter{})
rec := dnsrecorder.New(&test.ResponseWriter{})
rw.ServeDNS(ctx, rec, m)
resp := rec.Msg()
resp := rec.Msg
if resp.Question[0].Name != tc.to {
t.Errorf("Test %d: Expected Name to be '%s' but was '%s'", i, tc.to, resp.Question[0].Name)
}

View file

@ -1,289 +0,0 @@
package middleware
import (
"testing"
"github.com/miekg/coredns/middleware/test"
"github.com/miekg/dns"
)
func TestStateDo(t *testing.T) {
st := testState()
st.Do()
if st.do == 0 {
t.Fatalf("expected st.do to be set")
}
}
func TestStateRemote(t *testing.T) {
st := testState()
if st.IP() != "10.240.0.1" {
t.Fatalf("wrong IP from state")
}
p, err := st.Port()
if err != nil {
t.Fatalf("failed to get Port from state")
}
if p != "40212" {
t.Fatalf("wrong port from state")
}
}
func BenchmarkStateDo(b *testing.B) {
st := testState()
for i := 0; i < b.N; i++ {
st.Do()
}
}
func BenchmarkStateSize(b *testing.B) {
st := testState()
for i := 0; i < b.N; i++ {
st.Size()
}
}
func testState() State {
m := new(dns.Msg)
m.SetQuestion("example.com.", dns.TypeA)
m.SetEdns0(4097, true)
return State{W: &test.ResponseWriter{}, Req: m}
}
/*
func TestHeader(t *testing.T) {
state := getContextOrFail(t)
headerKey, headerVal := "Header1", "HeaderVal1"
state.Req.Header.Add(headerKey, headerVal)
actualHeaderVal := state.Header(headerKey)
if actualHeaderVal != headerVal {
t.Errorf("Expected header %s, found %s", headerVal, actualHeaderVal)
}
missingHeaderVal := state.Header("not-existing")
if missingHeaderVal != "" {
t.Errorf("Expected empty header value, found %s", missingHeaderVal)
}
}
func TestIP(t *testing.T) {
state := getContextOrFail(t)
tests := []struct {
inputRemoteAddr string
expectedIP string
}{
// Test 0 - ipv4 with port
{"1.1.1.1:1111", "1.1.1.1"},
// Test 1 - ipv4 without port
{"1.1.1.1", "1.1.1.1"},
// Test 2 - ipv6 with port
{"[::1]:11", "::1"},
// Test 3 - ipv6 without port and brackets
{"[2001:db8:a0b:12f0::1]", "[2001:db8:a0b:12f0::1]"},
// Test 4 - ipv6 with zone and port
{`[fe80:1::3%eth0]:44`, `fe80:1::3%eth0`},
}
for i, test := range tests {
testPrefix := getTestPrefix(i)
state.Req.RemoteAddr = test.inputRemoteAddr
actualIP := state.IP()
if actualIP != test.expectedIP {
t.Errorf(testPrefix+"Expected IP %s, found %s", test.expectedIP, actualIP)
}
}
}
func TestURL(t *testing.T) {
state := getContextOrFail(t)
inputURL := "http://localhost"
state.Req.RequestURI = inputURL
if inputURL != state.URI() {
t.Errorf("Expected url %s, found %s", inputURL, state.URI())
}
}
func TestHost(t *testing.T) {
tests := []struct {
input string
expectedHost string
shouldErr bool
}{
{
input: "localhost:123",
expectedHost: "localhost",
shouldErr: false,
},
{
input: "localhost",
expectedHost: "localhost",
shouldErr: false,
},
{
input: "[::]",
expectedHost: "",
shouldErr: true,
},
}
for _, test := range tests {
testHostOrPort(t, true, test.input, test.expectedHost, test.shouldErr)
}
}
func TestPort(t *testing.T) {
tests := []struct {
input string
expectedPort string
shouldErr bool
}{
{
input: "localhost:123",
expectedPort: "123",
shouldErr: false,
},
{
input: "localhost",
expectedPort: "80", // assuming 80 is the default port
shouldErr: false,
},
{
input: ":8080",
expectedPort: "8080",
shouldErr: false,
},
{
input: "[::]",
expectedPort: "",
shouldErr: true,
},
}
for _, test := range tests {
testHostOrPort(t, false, test.input, test.expectedPort, test.shouldErr)
}
}
func testHostOrPort(t *testing.T, isTestingHost bool, input, expectedResult string, shouldErr bool) {
state := getContextOrFail(t)
state.Req.Host = input
var actualResult, testedObject string
var err error
if isTestingHost {
actualResult, err = state.Host()
testedObject = "host"
} else {
actualResult, err = state.Port()
testedObject = "port"
}
if shouldErr && err == nil {
t.Errorf("Expected error, found nil!")
return
}
if !shouldErr && err != nil {
t.Errorf("Expected no error, found %s", err)
return
}
if actualResult != expectedResult {
t.Errorf("Expected %s %s, found %s", testedObject, expectedResult, actualResult)
}
}
func TestPathMatches(t *testing.T) {
state := getContextOrFail(t)
tests := []struct {
urlStr string
pattern string
shouldMatch bool
}{
// Test 0
{
urlStr: "http://localhost/",
pattern: "",
shouldMatch: true,
},
// Test 1
{
urlStr: "http://localhost",
pattern: "",
shouldMatch: true,
},
// Test 1
{
urlStr: "http://localhost/",
pattern: "/",
shouldMatch: true,
},
// Test 3
{
urlStr: "http://localhost/?param=val",
pattern: "/",
shouldMatch: true,
},
// Test 4
{
urlStr: "http://localhost/dir1/dir2",
pattern: "/dir2",
shouldMatch: false,
},
// Test 5
{
urlStr: "http://localhost/dir1/dir2",
pattern: "/dir1",
shouldMatch: true,
},
// Test 6
{
urlStr: "http://localhost:444/dir1/dir2",
pattern: "/dir1",
shouldMatch: true,
},
}
for i, test := range tests {
testPrefix := getTestPrefix(i)
var err error
state.Req.URL, err = url.Parse(test.urlStr)
if err != nil {
t.Fatalf("Failed to prepare test URL from string %s! Error was: %s", test.urlStr, err)
}
matches := state.PathMatches(test.pattern)
if matches != test.shouldMatch {
t.Errorf(testPrefix+"Expected and actual result differ: expected to match [%t], actual matches [%t]", test.shouldMatch, matches)
}
}
}
func initTestContext() (Context, error) {
body := bytes.NewBufferString("request body")
request, err := http.NewRequest("GET", "https://localhost", body)
if err != nil {
return Context{}, err
}
return Context{Root: http.Dir(os.TempDir()), Req: request}, nil
}
func getTestPrefix(testN int) string {
return fmt.Sprintf("Test [%d]: ", testN)
}
*/

View file

@ -1,27 +0,0 @@
package middleware
import "github.com/miekg/dns"
type Zones []string
// Matches checks is qname is a subdomain of any of the zones in z. The match
// will return the most specific zones that matches other. The empty string
// signals a not found condition.
func (z Zones) Matches(qname string) string {
zone := ""
for _, zname := range z {
if dns.IsSubDomain(zname, qname) {
if len(zname) > len(zone) {
zone = zname
}
}
}
return zone
}
// FullyQualify fully qualifies all zones in z.
func (z Zones) FullyQualify() {
for i, _ := range z {
z[i] = dns.Fqdn(z[i])
}
}

View file

@ -1,17 +1,16 @@
package middleware
package request
import (
"net"
"strings"
"time"
"github.com/miekg/coredns/middleware/pkg/edns"
"github.com/miekg/dns"
)
// This file contains the state functions available for use in the middlewares.
// State contains some connection state and is useful in middleware.
type State struct {
// Request contains some connection state and is useful in middleware.
type Request struct {
Req *dns.Msg
W dns.ResponseWriter
@ -24,38 +23,41 @@ type State struct {
name string
}
// Now returns the current timestamp in the specified format.
func (s *State) Now(format string) string { return time.Now().Format(format) }
// NowDate returns the current date/time that can be used in other time functions.
func (s *State) NowDate() time.Time { return time.Now() }
// NewWithQuestion returns a new request based on the old, but with a new question
// section in the request.
func (r *Request) NewWithQuestion(name string, typ uint16) Request {
req1 := Request{W: r.W, Req: r.Req.Copy()}
req1.Req.Question[0] = dns.Question{Name: dns.Fqdn(name), Qtype: dns.ClassINET, Qclass: typ}
return req1
}
// IP gets the (remote) IP address of the client making the request.
func (s *State) IP() string {
ip, _, err := net.SplitHostPort(s.W.RemoteAddr().String())
func (r *Request) IP() string {
ip, _, err := net.SplitHostPort(r.W.RemoteAddr().String())
if err != nil {
return s.W.RemoteAddr().String()
return r.W.RemoteAddr().String()
}
return ip
}
// Post gets the (remote) Port of the client making the request.
func (s *State) Port() (string, error) {
_, port, err := net.SplitHostPort(s.W.RemoteAddr().String())
func (r *Request) Port() string {
_, port, err := net.SplitHostPort(r.W.RemoteAddr().String())
if err != nil {
return "0", err
return "0"
}
return port, nil
return port
}
// RemoteAddr returns the net.Addr of the client that sent the current request.
func (s *State) RemoteAddr() string {
return s.W.RemoteAddr().String()
func (r *Request) RemoteAddr() string {
return r.W.RemoteAddr().String()
}
// Proto gets the protocol used as the transport. This will be udp or tcp.
func (s *State) Proto() string { return Proto(s.W) }
func (r *Request) Proto() string { return Proto(r.W) }
// FIXME(miek): why not a method on Request
// Proto gets the protocol used as the transport. This will be udp or tcp.
func Proto(w dns.ResponseWriter) string {
if _, ok := w.RemoteAddr().(*net.UDPAddr); ok {
@ -67,11 +69,10 @@ func Proto(w dns.ResponseWriter) string {
return "udp"
}
// Family returns the family of the transport.
// 1 for IPv4 and 2 for IPv6.
func (s *State) Family() int {
// Family returns the family of the transport, 1 for IPv4 and 2 for IPv6.
func (r *Request) Family() int {
var a net.IP
ip := s.W.RemoteAddr()
ip := r.W.RemoteAddr()
if i, ok := ip.(*net.UDPAddr); ok {
a = i.IP
}
@ -86,48 +87,49 @@ func (s *State) Family() int {
}
// Do returns if the request has the DO (DNSSEC OK) bit set.
func (s *State) Do() bool {
if s.do != 0 {
return s.do == doTrue
func (r *Request) Do() bool {
if r.do != 0 {
return r.do == doTrue
}
if o := s.Req.IsEdns0(); o != nil {
if o := r.Req.IsEdns0(); o != nil {
if o.Do() {
s.do = doTrue
r.do = doTrue
} else {
s.do = doFalse
r.do = doFalse
}
return o.Do()
}
s.do = doFalse
r.do = doFalse
return false
}
// UDPSize returns if UDP buffer size advertised in the requests OPT record.
// Size returns if UDP buffer size advertised in the requests OPT record.
// Or when the request was over TCP, we return the maximum allowed size of 64K.
func (s *State) Size() int {
if s.size != 0 {
return s.size
func (r *Request) Size() int {
if r.size != 0 {
return r.size
}
size := 0
if o := s.Req.IsEdns0(); o != nil {
if o := r.Req.IsEdns0(); o != nil {
if o.Do() == true {
s.do = doTrue
r.do = doTrue
} else {
s.do = doFalse
r.do = doFalse
}
size = int(o.UDPSize())
}
size = edns0Size(s.Proto(), size)
s.size = size
// TODO(miek) move edns.Size to dnsutil?
size = edns.Size(r.Proto(), size)
r.size = size
return size
}
// SizeAndDo adds an OPT record that the reflects the intent from state.
// SizeAndDo adds an OPT record that the reflects the intent from request.
// The returned bool indicated if an record was found and normalised.
func (s *State) SizeAndDo(m *dns.Msg) bool {
o := s.Req.IsEdns0() // TODO(miek): speed this up
func (r *Request) SizeAndDo(m *dns.Msg) bool {
o := r.Req.IsEdns0() // TODO(miek): speed this up
if o == nil {
return false
}
@ -163,8 +165,8 @@ const (
// the TC bit will be set regardless of protocol, even TCP message will get the bit, the client
// should then retry with pigeons.
// TODO(referral).
func (s *State) Scrub(reply *dns.Msg) (*dns.Msg, Result) {
size := s.Size()
func (r *Request) Scrub(reply *dns.Msg) (*dns.Msg, Result) {
size := r.Size()
l := reply.Len()
if size >= l {
return reply, ScrubIgnored
@ -173,7 +175,7 @@ func (s *State) Scrub(reply *dns.Msg) (*dns.Msg, Result) {
// If not delegation, drop additional section.
reply.Extra = nil
s.SizeAndDo(reply)
r.SizeAndDo(reply)
l = reply.Len()
if size >= l {
return reply, ScrubDone
@ -184,42 +186,42 @@ func (s *State) Scrub(reply *dns.Msg) (*dns.Msg, Result) {
}
// Type returns the type of the question as a string.
func (s *State) Type() string { return dns.Type(s.Req.Question[0].Qtype).String() }
func (r *Request) Type() string { return dns.Type(r.Req.Question[0].Qtype).String() }
// QType returns the type of the question as an uint16.
func (s *State) QType() uint16 { return s.Req.Question[0].Qtype }
func (r *Request) QType() uint16 { return r.Req.Question[0].Qtype }
// Name returns the name of the question in the request. Note
// this name will always have a closing dot and will be lower cased. After a call Name
// the value will be cached. To clear this caching call Clear.
func (s *State) Name() string {
if s.name != "" {
return s.name
func (r *Request) Name() string {
if r.name != "" {
return r.name
}
s.name = strings.ToLower(dns.Name(s.Req.Question[0].Name).String())
return s.name
r.name = strings.ToLower(dns.Name(r.Req.Question[0].Name).String())
return r.name
}
// QName returns the name of the question in the request.
func (s *State) QName() string { return dns.Name(s.Req.Question[0].Name).String() }
func (r *Request) QName() string { return dns.Name(r.Req.Question[0].Name).String() }
// Class returns the class of the question in the request.
func (s *State) Class() string { return dns.Class(s.Req.Question[0].Qclass).String() }
func (r *Request) Class() string { return dns.Class(r.Req.Question[0].Qclass).String() }
// QClass returns the class of the question in the request.
func (s *State) QClass() uint16 { return s.Req.Question[0].Qclass }
func (r *Request) QClass() uint16 { return r.Req.Question[0].Qclass }
// ErrorMessage returns an error message suitable for sending
// back to the client.
func (s *State) ErrorMessage(rcode int) *dns.Msg {
func (r *Request) ErrorMessage(rcode int) *dns.Msg {
m := new(dns.Msg)
m.SetRcode(s.Req, rcode)
m.SetRcode(r.Req, rcode)
return m
}
// Clear clears all caching from State s.
func (s *State) Clear() {
s.name = ""
// Clear clears all caching from Request s.
func (r *Request) Clear() {
r.name = ""
}
const (

55
request/request_test.go Normal file
View file

@ -0,0 +1,55 @@
package request
import (
"testing"
"github.com/miekg/coredns/middleware/test"
"github.com/miekg/dns"
)
func TestRequestDo(t *testing.T) {
st := testRequest()
st.Do()
if st.do == 0 {
t.Fatalf("Expected st.do to be set")
}
}
func TestRequestRemote(t *testing.T) {
st := testRequest()
if st.IP() != "10.240.0.1" {
t.Fatalf("Wrong IP from request")
}
p := st.Port()
if p == "" {
t.Fatalf("Failed to get Port from request")
}
if p != "40212" {
t.Fatalf("Wrong port from request")
}
}
func BenchmarkRequestDo(b *testing.B) {
st := testRequest()
for i := 0; i < b.N; i++ {
st.Do()
}
}
func BenchmarkRequestSize(b *testing.B) {
st := testRequest()
for i := 0; i < b.N; i++ {
st.Size()
}
}
func testRequest() Request {
m := new(dns.Msg)
m.SetQuestion("example.com.", dns.TypeA)
m.SetEdns0(4097, true)
return Request{W: &test.ResponseWriter{}, Req: m}
}

View file

@ -9,11 +9,11 @@ import (
"testing"
"time"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/etcd"
"github.com/miekg/coredns/middleware/etcd/msg"
"github.com/miekg/coredns/middleware/proxy"
"github.com/miekg/coredns/middleware/test"
"github.com/miekg/coredns/request"
etcdc "github.com/coreos/etcd/client"
"github.com/miekg/dns"
@ -67,7 +67,7 @@ func TestEtcdStubAndProxyLookup(t *testing.T) {
}
p := proxy.New([]string{udp}) // use udp port from the server
state := middleware.State{W: &test.ResponseWriter{}, Req: new(dns.Msg)}
state := request.Request{W: &test.ResponseWriter{}, Req: new(dns.Msg)}
resp, err := p.Lookup(state, "example.com.", dns.TypeA)
if err != nil {
t.Error("Expected to receive reply, but didn't")

View file

@ -5,9 +5,9 @@ import (
"log"
"testing"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/proxy"
"github.com/miekg/coredns/middleware/test"
"github.com/miekg/coredns/request"
"github.com/miekg/dns"
)
@ -46,7 +46,7 @@ func TestLookupProxy(t *testing.T) {
log.SetOutput(ioutil.Discard)
p := proxy.New([]string{udp})
state := middleware.State{W: &test.ResponseWriter{}, Req: new(dns.Msg)}
state := request.Request{W: &test.ResponseWriter{}, Req: new(dns.Msg)}
resp, err := p.Lookup(state, "example.org.", dns.TypeA)
if err != nil {
t.Fatal("Expected to receive reply, but didn't")

View file

@ -17,7 +17,7 @@ func TestProxyToChaosServer(t *testing.T) {
t.Fatalf("could not get CoreDNS serving instance: %s", err)
}
udpChaos, tcpChaos := CoreDNSServerPorts(chaos, 0)
udpChaos, _ := CoreDNSServerPorts(chaos, 0)
defer chaos.Stop()
corefileProxy := `.:0 {
@ -32,24 +32,23 @@ func TestProxyToChaosServer(t *testing.T) {
udp, _ := CoreDNSServerPorts(proxy, 0)
defer proxy.Stop()
chaosTest(t, udpChaos, "udp")
chaosTest(t, tcpChaos, "tcp")
chaosTest(t, udpChaos)
chaosTest(t, udp, "udp")
chaosTest(t, udp)
// chaosTest(t, tcp, "tcp"), commented out because we use the original transport to reach the
// proxy and we only forward to the udp port.
}
func chaosTest(t *testing.T, server, net string) {
func chaosTest(t *testing.T, server string) {
m := Msg("version.bind.", dns.TypeTXT, nil)
m.Question[0].Qclass = dns.ClassCHAOS
r, err := Exchange(m, server, net)
r, err := dns.Exchange(m, server)
if err != nil {
t.Fatalf("Could not send message: %s", err)
}
if r.Rcode != dns.RcodeSuccess || len(r.Answer) == 0 {
t.Fatalf("Expected successful reply on %s, got %s", net, dns.RcodeToString[r.Rcode])
t.Fatalf("Expected successful reply, got %s", dns.RcodeToString[r.Rcode])
}
if r.Answer[0].String() != `version.bind. 0 CH TXT "CoreDNS-001"` {
t.Fatalf("Expected version.bind. reply, got %s", r.Answer[0].String())

View file

@ -1,10 +1,6 @@
package test
import (
"github.com/miekg/coredns/middleware"
"github.com/miekg/dns"
)
import "github.com/miekg/dns"
func Msg(zone string, typ uint16, o *dns.OPT) *dns.Msg {
m := new(dns.Msg)
@ -14,9 +10,3 @@ func Msg(zone string, typ uint16, o *dns.OPT) *dns.Msg {
}
return m
}
func Exchange(m *dns.Msg, server, net string) (*dns.Msg, error) {
c := new(dns.Client)
c.Net = net
return middleware.Exchange(c, m, server)
}