diff --git a/plugin/file/delete_test.go b/plugin/file/delete_test.go new file mode 100644 index 000000000..26ee64e3a --- /dev/null +++ b/plugin/file/delete_test.go @@ -0,0 +1,65 @@ +package file + +import ( + "bytes" + "fmt" + "testing" + + "github.com/coredns/coredns/plugin/file/tree" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +/* +Create a zone with: + + apex + / + a MX + a A + +Test that: we create the proper tree and that delete +deletes the correct elements +*/ + +var tz = NewZone("example.org.", "db.example.org.") + +type treebuf struct { + *bytes.Buffer +} + +func (t *treebuf) printFunc(e *tree.Elem, rrs map[uint16][]dns.RR) error { + fmt.Fprintf(t.Buffer, "%v\n", rrs) // should be fixed order in new go versions. + return nil +} + +func TestZoneInsertAndDelete(t *testing.T) { + tz.Insert(test.SOA("example.org. IN SOA 1 2 3 4 5")) + + if x := tz.Apex.SOA.Header().Name; x != "example.org." { + t.Errorf("Failed to insert SOA, expected %s, git %s", "example.org.", x) + } + + // Insert two RRs and then remove one. + tz.Insert(test.A("a.example.org. IN A 127.0.0.1")) + tz.Insert(test.MX("a.example.org. IN MX 10 mx.example.org.")) + + tz.Delete(test.MX("a.example.org. IN MX 10 mx.example.org.")) + + tb := treebuf{new(bytes.Buffer)} + + tz.Walk(tb.printFunc) + if tb.String() != "map[1:[a.example.org.\t3600\tIN\tA\t127.0.0.1]]\n" { + t.Errorf("Expected 1 A record in tree, got %s", tb.String()) + } + + tz.Delete(test.A("a.example.org. IN A 127.0.0.1")) + + tb.Reset() + + tz.Walk(tb.printFunc) + if tb.String() != "" { + t.Errorf("Expected no record in tree, got %s", tb.String()) + } +} diff --git a/plugin/file/file.go b/plugin/file/file.go index bc582cfaa..169720e64 100644 --- a/plugin/file/file.go +++ b/plugin/file/file.go @@ -119,7 +119,6 @@ func (s *serialErr) Error() string { // If serial >= 0 it will reload the zone, if the SOA hasn't changed // it returns an error indicating nothing was read. func Parse(f io.Reader, origin, fileName string, serial int64) (*Zone, error) { - zp := dns.NewZoneParser(f, dns.Fqdn(origin), fileName) zp.SetIncludeAllowed(true) z := NewZone(origin, fileName) diff --git a/plugin/file/lookup.go b/plugin/file/lookup.go index 3a72a6163..14dfb6f7d 100644 --- a/plugin/file/lookup.go +++ b/plugin/file/lookup.go @@ -106,14 +106,14 @@ func (z *Zone) Lookup(ctx context.Context, state request.Request, qname string) } // If we see DNAME records, we should return those. - if dnamerrs := elem.Types(dns.TypeDNAME); dnamerrs != nil { + if dnamerrs := elem.Type(dns.TypeDNAME); dnamerrs != nil { // Only one DNAME is allowed per name. We just pick the first one to synthesize from. dname := dnamerrs[0] if cname := synthesizeCNAME(state.Name(), dname.(*dns.DNAME)); cname != nil { answer, ns, extra, rcode := z.additionalProcessing(ctx, state, elem, []dns.RR{cname}) if do { - sigs := elem.Types(dns.TypeRRSIG) + sigs := elem.Type(dns.TypeRRSIG) sigs = signatureForSubType(sigs, dns.TypeDNAME) dnamerrs = append(dnamerrs, sigs...) } @@ -130,7 +130,7 @@ func (z *Zone) Lookup(ctx context.Context, state request.Request, qname string) } // If we see NS records, it means the name as been delegated, and we should return the delegation. - if nsrrs := elem.Types(dns.TypeNS); nsrrs != nil { + if nsrrs := elem.Type(dns.TypeNS); nsrrs != nil { // If the query is specifically for DS and the qname matches the delegated name, we should // return the DS in the answer section and leave the rest empty, i.e. just continue the loop @@ -160,11 +160,11 @@ func (z *Zone) Lookup(ctx context.Context, state request.Request, qname string) // Found entire name. if found && shot { - if rrs := elem.Types(dns.TypeCNAME); len(rrs) > 0 && qtype != dns.TypeCNAME { + if rrs := elem.Type(dns.TypeCNAME); len(rrs) > 0 && qtype != dns.TypeCNAME { return z.additionalProcessing(ctx, state, elem, rrs) } - rrs := elem.Types(qtype, qname) + rrs := elem.Type(qtype) // NODATA if len(rrs) == 0 { @@ -181,7 +181,7 @@ func (z *Zone) Lookup(ctx context.Context, state request.Request, qname string) additional := additionalProcessing(z, rrs, do) if do { - sigs := elem.Types(dns.TypeRRSIG) + sigs := elem.Type(dns.TypeRRSIG) sigs = signatureForSubType(sigs, qtype) rrs = append(rrs, sigs...) } @@ -196,11 +196,11 @@ func (z *Zone) Lookup(ctx context.Context, state request.Request, qname string) if wildElem != nil { auth := z.ns(do) - if rrs := wildElem.Types(dns.TypeCNAME, qname); len(rrs) > 0 { + if rrs := wildElem.TypeForWildcard(dns.TypeCNAME, qname); len(rrs) > 0 { return z.additionalProcessing(ctx, state, wildElem, rrs) } - rrs := wildElem.Types(qtype, qname) + rrs := wildElem.TypeForWildcard(qtype, qname) // NODATA response. if len(rrs) == 0 { @@ -219,7 +219,7 @@ func (z *Zone) Lookup(ctx context.Context, state request.Request, qname string) auth = append(auth, nsec...) } - sigs := wildElem.Types(dns.TypeRRSIG, qname) + sigs := wildElem.TypeForWildcard(dns.TypeRRSIG, qname) sigs = signatureForSubType(sigs, qtype) rrs = append(rrs, sigs...) @@ -272,9 +272,9 @@ Out: // Return type tp from e and add signatures (if they exists) and do is true. func (z *Zone) typeFromElem(elem *tree.Elem, tp uint16, do bool) []dns.RR { - rrs := elem.Types(tp) + rrs := elem.Type(tp) if do { - sigs := elem.Types(dns.TypeRRSIG) + sigs := elem.Type(dns.TypeRRSIG) sigs = signatureForSubType(sigs, tp) if len(sigs) > 0 { rrs = append(rrs, sigs...) @@ -306,7 +306,7 @@ func (z *Zone) additionalProcessing(ctx context.Context, state request.Request, do := state.Do() if do { - sigs := elem.Types(dns.TypeRRSIG) + sigs := elem.Type(dns.TypeRRSIG) sigs = signatureForSubType(sigs, dns.TypeCNAME) if len(sigs) > 0 { rrs = append(rrs, sigs...) @@ -323,12 +323,12 @@ func (z *Zone) additionalProcessing(ctx context.Context, state request.Request, i := 0 Redo: - cname := elem.Types(dns.TypeCNAME) + cname := elem.Type(dns.TypeCNAME) if len(cname) > 0 { rrs = append(rrs, cname...) if do { - sigs := elem.Types(dns.TypeRRSIG) + sigs := elem.Type(dns.TypeRRSIG) sigs = signatureForSubType(sigs, dns.TypeCNAME) if len(sigs) > 0 { rrs = append(rrs, sigs...) @@ -354,7 +354,7 @@ Redo: rrs = append(rrs, targets...) if do { - sigs := elem.Types(dns.TypeRRSIG) + sigs := elem.Type(dns.TypeRRSIG) sigs = signatureForSubType(sigs, qtype) if len(sigs) > 0 { rrs = append(rrs, sigs...) @@ -416,9 +416,9 @@ func (z *Zone) searchGlue(name string, do bool) []dns.RR { // A if elem, found := z.Tree.Search(name); found { - glue = append(glue, elem.Types(dns.TypeA)...) + glue = append(glue, elem.Type(dns.TypeA)...) if do { - sigs := elem.Types(dns.TypeRRSIG) + sigs := elem.Type(dns.TypeRRSIG) sigs = signatureForSubType(sigs, dns.TypeA) glue = append(glue, sigs...) } @@ -426,9 +426,9 @@ func (z *Zone) searchGlue(name string, do bool) []dns.RR { // AAAA if elem, found := z.Tree.Search(name); found { - glue = append(glue, elem.Types(dns.TypeAAAA)...) + glue = append(glue, elem.Type(dns.TypeAAAA)...) if do { - sigs := elem.Types(dns.TypeRRSIG) + sigs := elem.Type(dns.TypeRRSIG) sigs = signatureForSubType(sigs, dns.TypeAAAA) glue = append(glue, sigs...) } @@ -456,9 +456,9 @@ func additionalProcessing(z *Zone, answer []dns.RR, do bool) (extra []dns.RR) { continue } - sigs := elem.Types(dns.TypeRRSIG) + sigs := elem.Type(dns.TypeRRSIG) for _, addr := range []uint16{dns.TypeA, dns.TypeAAAA} { - if a := elem.Types(addr); a != nil { + if a := elem.Type(addr); a != nil { extra = append(extra, a...) if do { sig := signatureForSubType(sigs, addr) diff --git a/plugin/file/tree/all.go b/plugin/file/tree/all.go index fd806365f..e1fc5b392 100644 --- a/plugin/file/tree/all.go +++ b/plugin/file/tree/all.go @@ -1,6 +1,6 @@ package tree -// All traverses tree and returns all elements +// All traverses tree and returns all elements. func (t *Tree) All() []*Elem { if t.Root == nil { return nil @@ -19,30 +19,3 @@ func (n *Node) all(found []*Elem) []*Elem { } return found } - -// Do performs fn on all values stored in the tree. A boolean is returned indicating whether the -// Do traversal was interrupted by an Operation returning true. If fn alters stored values' sort -// relationships, future tree operation behaviors are undefined. -func (t *Tree) Do(fn func(e *Elem) bool) bool { - if t.Root == nil { - return false - } - return t.Root.do(fn) -} - -func (n *Node) do(fn func(e *Elem) bool) (done bool) { - if n.Left != nil { - done = n.Left.do(fn) - if done { - return - } - } - done = fn(n.Elem) - if done { - return - } - if n.Right != nil { - done = n.Right.do(fn) - } - return -} diff --git a/plugin/file/tree/elem.go b/plugin/file/tree/elem.go index 6317cc912..c1909649d 100644 --- a/plugin/file/tree/elem.go +++ b/plugin/file/tree/elem.go @@ -15,20 +15,34 @@ func newElem(rr dns.RR) *Elem { return &e } -// Types returns the RRs with type qtype from e. If qname is given (only the -// first one is used), the RR are copied and the owner is replaced with qname[0]. -func (e *Elem) Types(qtype uint16, qname ...string) []dns.RR { +// Types returns the types of the records in e. The returned list is not sorted. +func (e *Elem) Types() []uint16 { + t := make([]uint16, len(e.m)) + i := 0 + for ty := range e.m { + t[i] = ty + i++ + } + return t +} + +// Type returns the RRs with type qtype from e. +func (e *Elem) Type(qtype uint16) []dns.RR { return e.m[qtype] } + +// TypeForWildcard returns the RRs with type qtype from e. The ownername returned is set to qname. +func (e *Elem) TypeForWildcard(qtype uint16, qname string) []dns.RR { rrs := e.m[qtype] - if rrs != nil && len(qname) > 0 { - copied := make([]dns.RR, len(rrs)) - for i := range rrs { - copied[i] = dns.Copy(rrs[i]) - copied[i].Header().Name = qname[0] - } - return copied + if rrs == nil { + return nil } - return rrs + + copied := make([]dns.RR, len(rrs)) + for i := range rrs { + copied[i] = dns.Copy(rrs[i]) + copied[i].Header().Name = qname + } + return copied } // All returns all RRs from e, regardless of type. @@ -52,13 +66,10 @@ func (e *Elem) Name() string { return "" } -// Empty returns true is e does not contain any RRs, i.e. is an -// empty-non-terminal. -func (e *Elem) Empty() bool { - return len(e.m) == 0 -} +// Empty returns true is e does not contain any RRs, i.e. is an empty-non-terminal. +func (e *Elem) Empty() bool { return len(e.m) == 0 } -// Insert inserts rr into e. If rr is equal to existing rrs this is a noop. +// Insert inserts rr into e. If rr is equal to existing RRs, the RR will be added anyway. func (e *Elem) Insert(rr dns.RR) { t := rr.Header().Rrtype if e.m == nil { @@ -71,66 +82,20 @@ func (e *Elem) Insert(rr dns.RR) { e.m[t] = []dns.RR{rr} return } - for _, er := range rrs { - if equalRdata(er, rr) { - return - } - } rrs = append(rrs, rr) e.m[t] = rrs } -// Delete removes rr from e. When e is empty after the removal the returned bool is true. -func (e *Elem) Delete(rr dns.RR) (empty bool) { +// Delete removes all RRs of type rr.Header().Rrtype from e. +func (e *Elem) Delete(rr dns.RR) { if e.m == nil { - return true - } - - t := rr.Header().Rrtype - rrs, ok := e.m[t] - if !ok { return } - for i, er := range rrs { - if equalRdata(er, rr) { - rrs = removeFromSlice(rrs, i) - e.m[t] = rrs - empty = len(rrs) == 0 - if empty { - delete(e.m, t) - } - return - } - } - return + t := rr.Header().Rrtype + delete(e.m, t) } // Less is a tree helper function that calls less. func Less(a *Elem, name string) int { return less(name, a.Name()) } - -// Assuming the same type and name this will check if the rdata is equal as well. -func equalRdata(a, b dns.RR) bool { - switch x := a.(type) { - // TODO(miek): more types, i.e. all types. + tests for this. - case *dns.A: - return x.A.Equal(b.(*dns.A).A) - case *dns.AAAA: - return x.AAAA.Equal(b.(*dns.AAAA).AAAA) - case *dns.MX: - if x.Mx == b.(*dns.MX).Mx && x.Preference == b.(*dns.MX).Preference { - return true - } - } - return false -} - -// removeFromSlice removes index i from the slice. -func removeFromSlice(rrs []dns.RR, i int) []dns.RR { - if i >= len(rrs) { - return rrs - } - rrs = append(rrs[:i], rrs[i+1:]...) - return rrs -} diff --git a/plugin/file/tree/tree.go b/plugin/file/tree/tree.go index ed33c09a4..3aeeba4d5 100644 --- a/plugin/file/tree/tree.go +++ b/plugin/file/tree/tree.go @@ -275,7 +275,8 @@ func (n *Node) deleteMax() (root *Node, d int) { return } -// Delete removes rr from the tree, is the node turns empty, that node is deleted with DeleteNode. +// Delete removes all RRs of type rr.Header().Rrtype from e. If after the deletion of rr the node is empty the +// entire node is deleted. func (t *Tree) Delete(rr dns.RR) { if t.Root == nil { return @@ -283,15 +284,13 @@ func (t *Tree) Delete(rr dns.RR) { el, _ := t.Search(rr.Header().Name) if el == nil { - t.deleteNode(rr) return } - // Delete from this element. - empty := el.Delete(rr) - if empty { + el.Delete(rr) + if el.Empty() { t.deleteNode(rr) - return } + return } // DeleteNode deletes the node that matches rr according to Less(). diff --git a/plugin/file/tree/walk.go b/plugin/file/tree/walk.go new file mode 100644 index 000000000..00accbafa --- /dev/null +++ b/plugin/file/tree/walk.go @@ -0,0 +1,30 @@ +package tree + +import "github.com/miekg/dns" + +// Walk performs fn on all values stored in the tree. If a non-nil error is returned the +// Walk was interrupted by an fn returning that error. If fn alters stored values' sort +// relationships, future tree operation behaviors are undefined. +func (t *Tree) Walk(fn func(e *Elem, rrs map[uint16][]dns.RR) error) error { + if t.Root == nil { + return nil + } + return t.Root.walk(fn) +} + +func (n *Node) walk(fn func(e *Elem, rrs map[uint16][]dns.RR) error) error { + if n.Left != nil { + if err := n.Left.walk(fn); err != nil { + return err + } + } + if err := fn(n.Elem, n.Elem.m); err != nil { + return err + } + if n.Right != nil { + if err := n.Right.walk(fn); err != nil { + return err + } + } + return nil +}