diff --git a/middleware/file/closest.go b/middleware/file/closest.go index df74cad50..ab0cbb00a 100644 --- a/middleware/file/closest.go +++ b/middleware/file/closest.go @@ -3,18 +3,18 @@ package file import "github.com/miekg/dns" // ClosestEncloser returns the closest encloser for rr. -func (z *Zone) ClosestEncloser(rr dns.RR) string { +func (z *Zone) ClosestEncloser(qname string, qtype uint16) string { // tree/tree.go does not store a parent *Node pointer, so we can't // just follow up the tree. TODO(miek): fix. - offset, end := dns.NextLabel(rr.Header().Name, 0) + offset, end := dns.NextLabel(qname, 0) for !end { - elem, _ := z.Tree.Get(rr) + elem, _ := z.Tree.Search(qname, qtype) if elem != nil { return elem.Name() } - rr.Header().Name = rr.Header().Name[offset:] + qname = qname[offset:] - offset, end = dns.NextLabel(rr.Header().Name, offset) + offset, end = dns.NextLabel(qname, offset) } return z.SOA.Header().Name @@ -22,8 +22,8 @@ func (z *Zone) ClosestEncloser(rr dns.RR) string { // nameErrorProof finds the closest encloser and return an NSEC that proofs // the wildcard does not exist and an NSEC that proofs the name does no exist. -func (z *Zone) nameErrorProof(rr dns.RR) []dns.RR { - elem := z.Tree.Prev(rr) +func (z *Zone) nameErrorProof(qname string, qtype uint16) []dns.RR { + elem := z.Tree.Prev(qname) if elem == nil { return nil } @@ -37,10 +37,8 @@ func (z *Zone) nameErrorProof(rr dns.RR) []dns.RR { } // We do this lookup twice, once for wildcard and once for the name proof. TODO(miek): fix - ce := z.ClosestEncloser(rr) - wildcard := "*." + ce - rr.Header().Name = wildcard - elem = z.Tree.Prev(rr) + ce := z.ClosestEncloser(qname, qtype) + elem = z.Tree.Prev("*." + ce) if elem == nil { // Root? return nil diff --git a/middleware/file/closest_test.go b/middleware/file/closest_test.go index db0b718b2..91b65f231 100644 --- a/middleware/file/closest_test.go +++ b/middleware/file/closest_test.go @@ -25,11 +25,8 @@ func TestClosestEncloser(t *testing.T) { {"blaat.a.miek.nl.", "a.miek.nl."}, } - mk, _ := dns.TypeToRR[dns.TypeA] - rr := mk() for _, tc := range tests { - rr.Header().Name = tc.in - ce := z.ClosestEncloser(rr) + ce := z.ClosestEncloser(tc.in, dns.TypeA) if ce != tc.out { t.Errorf("expected ce to be %s for %s, got %s", tc.out, tc.in, ce) } diff --git a/middleware/file/lookup.go b/middleware/file/lookup.go index bc048c66a..ddbfec6f5 100644 --- a/middleware/file/lookup.go +++ b/middleware/file/lookup.go @@ -19,33 +19,21 @@ const ( // Lookup looks up qname and qtype in the zone. When do is true DNSSEC records are included. // Three sets of records are returned, one for the answer, one for authority and one for the additional section. func (z *Zone) Lookup(qname string, qtype uint16, do bool) ([]dns.RR, []dns.RR, []dns.RR, Result) { - var rr dns.RR - mk, known := dns.TypeToRR[qtype] - if !known { - return nil, nil, nil, ServerFailure - } else { - rr = mk() - } if qtype == dns.TypeSOA { return z.lookupSOA(do) } - // Misuse rr to be a question. - rr.Header().Rrtype = qtype - rr.Header().Name = qname - - elem, res := z.Tree.Get(rr) + elem, res := z.Tree.Search(qname, qtype) if elem == nil { if res == tree.EmptyNonTerminal { - return z.emptyNonTerminal(rr, do) + return z.emptyNonTerminal(qname, do) } - return z.nameError(rr, do) + return z.nameError(qname, qtype, do) } rrs := elem.Types(dns.TypeCNAME) if len(rrs) > 0 { // should only ever be 1 actually; TODO(miek) check for this? - rr.Header().Name = rrs[0].(*dns.CNAME).Target - return z.lookupCNAME(rrs, rr, do) + return z.lookupCNAME(rrs, qtype, do) } rrs = elem.Types(qtype) @@ -67,33 +55,29 @@ func (z *Zone) noData(elem *tree.Elem, do bool) ([]dns.RR, []dns.RR, []dns.RR, R return nil, append(soa, nsec...), nil, Success } -func (z *Zone) emptyNonTerminal(rr dns.RR, do bool) ([]dns.RR, []dns.RR, []dns.RR, Result) { +func (z *Zone) emptyNonTerminal(qname string, do bool) ([]dns.RR, []dns.RR, []dns.RR, Result) { soa, _, _, _ := z.lookupSOA(do) - elem := z.Tree.Prev(rr) + elem := z.Tree.Prev(qname) nsec := z.lookupNSEC(elem, do) return nil, append(soa, nsec...), nil, Success } -func (z *Zone) nameError(rr dns.RR, do bool) ([]dns.RR, []dns.RR, []dns.RR, Result) { +func (z *Zone) nameError(qname string, qtype uint16, do bool) ([]dns.RR, []dns.RR, []dns.RR, Result) { // Is there a wildcard? - rr1 := dns.Copy(rr) - rr1.Header().Name = rr.Header().Name - rr1.Header().Rrtype = rr.Header().Rrtype - ce := z.ClosestEncloser(rr1) - rr1.Header().Name = "*." + ce - elem, _ := z.Tree.Get(rr1) // use result here? + ce := z.ClosestEncloser(qname, qtype) + elem, _ := z.Tree.Search("*."+ce, qtype) // use result here? if elem != nil { - ret := elem.Types(rr1.Header().Rrtype) // there can only be one of these (or zero) + ret := elem.Types(qtype) // there can only be one of these (or zero) switch { case ret != nil: if do { sigs := elem.Types(dns.TypeRRSIG) - sigs = signatureForSubType(sigs, rr.Header().Rrtype) + sigs = signatureForSubType(sigs, qtype) ret = append(ret, sigs...) } - ret = wildcardReplace(rr, ce, ret) + ret = wildcardReplace(qname, ce, ret) return ret, nil, nil, Success case ret == nil: // nodata, nsec from the wildcard - type does not exist @@ -106,7 +90,7 @@ func (z *Zone) nameError(rr dns.RR, do bool) ([]dns.RR, []dns.RR, []dns.RR, Resu ret := []dns.RR{z.SOA} if do { ret = append(ret, z.SIG...) - ret = append(ret, z.nameErrorProof(rr)...) + ret = append(ret, z.nameErrorProof(qname, qtype)...) } return nil, ret, nil, NameError } @@ -135,15 +119,15 @@ func (z *Zone) lookupNSEC(elem *tree.Elem, do bool) []dns.RR { return nsec } -func (z *Zone) lookupCNAME(rrs []dns.RR, rr dns.RR, do bool) ([]dns.RR, []dns.RR, []dns.RR, Result) { - elem, _ := z.Tree.Get(rr) +func (z *Zone) lookupCNAME(rrs []dns.RR, qtype uint16, do bool) ([]dns.RR, []dns.RR, []dns.RR, Result) { + elem, _ := z.Tree.Search(rrs[0].(*dns.CNAME).Target, qtype) if elem == nil { return rrs, nil, nil, Success } - extra := cnameForType(elem.All(), rr.Header().Rrtype) + extra := cnameForType(elem.All(), qtype) if do { sigs := elem.Types(dns.TypeRRSIG) - sigs = signatureForSubType(sigs, rr.Header().Rrtype) + sigs = signatureForSubType(sigs, qtype) if len(sigs) > 0 { extra = append(extra, sigs...) } @@ -175,25 +159,13 @@ func signatureForSubType(rrs []dns.RR, subtype uint16) []dns.RR { return sigs } -// wildcardReplace replaces the first wildcard with label. -func wildcardReplace(rr dns.RR, ce string, rrs []dns.RR) []dns.RR { - // Get how many labels the ce is off from the fullname, this is how much of the - // original rr's '*' we must replace. - labels := dns.CountLabel(rr.Header().Name) - dns.CountLabel(ce) // can not be 0, TODO(miek): check - - indexes := dns.Split(rr.Header().Name) - if labels >= len(indexes) { - // TODO(miek): yes then what? - // Is the == right here? - return nil - } - replacement := rr.Header().Name[:indexes[labels]] - +// wildcardReplace replaces the ownername with the original query name. +func wildcardReplace(qname, ce string, rrs []dns.RR) []dns.RR { // need to copy here, otherwise we change in zone stuff ret := make([]dns.RR, len(rrs)) for i, r := range rrs { ret[i] = dns.Copy(r) - ret[i].Header().Name = replacement + r.Header().Name[2:] + ret[i].Header().Name = qname } return ret } diff --git a/middleware/file/tree/elem.go b/middleware/file/tree/elem.go index 4008e8380..8698a9317 100644 --- a/middleware/file/tree/elem.go +++ b/middleware/file/tree/elem.go @@ -91,9 +91,8 @@ func (e *Elem) Delete(rr dns.RR) (empty bool) { return } -func Less(a *Elem, rr dns.RR) int { - return middleware.Less(rr.Header().Name, a.Name()) -} +// Less is a tree helper function that calles middleware.Less. +func Less(a *Elem, name string) int { return middleware.Less(name, a.Name()) } // Assuming the same type and name this will check if the rdata is equal as well. func equalRdata(a, b dns.RR) bool { diff --git a/middleware/file/tree/tree.go b/middleware/file/tree/tree.go index d0ecfcd94..c3e437ec4 100644 --- a/middleware/file/tree/tree.go +++ b/middleware/file/tree/tree.go @@ -23,7 +23,7 @@ const ( BU23 ) -// Result is a result of a Get lookup. +// Result is a result of a Search. type Result int const ( @@ -149,22 +149,22 @@ func (t *Tree) Len() int { return t.Count } -// Get returns the first match of rr in the Tree. -func (t *Tree) Get(rr dns.RR) (*Elem, Result) { +// Search returns the first match of qname/qtype in the Tree. +func (t *Tree) Search(qname string, qtype uint16) (*Elem, Result) { if t.Root == nil { return nil, NameError } - n, res := t.Root.search(rr) + n, res := t.Root.search(qname, qtype) if n == nil { return nil, res } return n.Elem, res } -func (n *Node) search(rr dns.RR) (*Node, Result) { +func (n *Node) search(qname string, qtype uint16) (*Node, Result) { old := n for n != nil { - switch c := Less(n.Elem, rr); { + switch c := Less(n.Elem, qname); { case c == 0: return n, Found case c < 0: @@ -175,7 +175,7 @@ func (n *Node) search(rr dns.RR) (*Node, Result) { n = n.Right } } - if dns.CountLabel(rr.Header().Name) < dns.CountLabel(old.Elem.Name()) { + if dns.CountLabel(qname) < dns.CountLabel(old.Elem.Name()) { return n, EmptyNonTerminal } @@ -205,7 +205,7 @@ func (n *Node) insert(rr dns.RR) (root *Node, d int) { } } - switch c := Less(n.Elem, rr); { + switch c := Less(n.Elem, rr.Header().Name); { case c == 0: n.Elem.Insert(rr) case c < 0: @@ -297,7 +297,7 @@ func (t *Tree) Delete(rr dns.RR) { return } - el, _ := t.Get(rr) + el, _ := t.Search(rr.Header().Name, rr.Header().Rrtype) if el == nil { t.DeleteNode(rr) return @@ -325,7 +325,7 @@ func (t *Tree) DeleteNode(rr dns.RR) { } func (n *Node) delete(rr dns.RR) (root *Node, d int) { - if Less(n.Elem, rr) < 0 { + if Less(n.Elem, rr.Header().Name) < 0 { if n.Left != nil { if n.Left.color() == Black && n.Left.Left.color() == Black { n = n.moveRedLeft() @@ -336,14 +336,14 @@ func (n *Node) delete(rr dns.RR) (root *Node, d int) { if n.Left.color() == Red { n = n.rotateRight() } - if n.Right == nil && Less(n.Elem, rr) == 0 { + if n.Right == nil && Less(n.Elem, rr.Header().Name) == 0 { return nil, -1 } if n.Right != nil { if n.Right.color() == Black && n.Right.Left.color() == Black { n = n.moveRedRight() } - if Less(n.Elem, rr) == 0 { + if Less(n.Elem, rr.Header().Name) == 0 { n.Elem = n.Right.min().Elem n.Right, d = n.Right.deleteMin() } else { @@ -384,58 +384,58 @@ func (n *Node) max() *Node { return n } -// Prev returns the greatest value equal to or less than the rr according to Less(). -func (t *Tree) Prev(rr dns.RR) *Elem { +// Prev returns the greatest value equal to or less than the qname according to Less(). +func (t *Tree) Prev(qname string) *Elem { if t.Root == nil { return nil } - n := t.Root.floor(rr) + n := t.Root.floor(qname) if n == nil { return nil } return n.Elem } -func (n *Node) floor(rr dns.RR) *Node { +func (n *Node) floor(qname string) *Node { if n == nil { return nil } - switch c := Less(n.Elem, rr); { + switch c := Less(n.Elem, qname); { case c == 0: return n case c < 0: - return n.Left.floor(rr) + return n.Left.floor(qname) default: - if r := n.Right.floor(rr); r != nil { + if r := n.Right.floor(qname); r != nil { return r } } return n } -// Next returns the smallest value equal to or greater than the rr according to Less(). -func (t *Tree) Next(rr dns.RR) *Elem { +// Next returns the smallest value equal to or greater than the qname according to Less(). +func (t *Tree) Next(qname string) *Elem { if t.Root == nil { return nil } - n := t.Root.ceil(rr) + n := t.Root.ceil(qname) if n == nil { return nil } return n.Elem } -func (n *Node) ceil(rr dns.RR) *Node { +func (n *Node) ceil(qname string) *Node { if n == nil { return nil } - switch c := Less(n.Elem, rr); { + switch c := Less(n.Elem, qname); { case c == 0: return n case c > 0: - return n.Right.ceil(rr) + return n.Right.ceil(qname) default: - if l := n.Left.ceil(rr); l != nil { + if l := n.Left.ceil(qname); l != nil { return l } }