diff --git a/plugin/cache/cache_test.go b/plugin/cache/cache_test.go index 4afaf73c4..b32353372 100644 --- a/plugin/cache/cache_test.go +++ b/plugin/cache/cache_test.go @@ -191,7 +191,7 @@ func TestCache(t *testing.T) { m := tc.in.Msg() m = cacheMsg(m, tc) - state := request.Request{W: nil, Req: m} + state := request.Request{W: &test.ResponseWriter{}, Req: m} mt, _ := response.Typify(m, utc) valid, k := key(state.Name(), m, mt, state.Do()) diff --git a/plugin/pkg/edns/edns.go b/plugin/pkg/edns/edns.go index 68fb03865..31f57ea9b 100644 --- a/plugin/pkg/edns/edns.go +++ b/plugin/pkg/edns/edns.go @@ -63,7 +63,7 @@ func Version(req *dns.Msg) (*dns.Msg, error) { } // Size returns a normalized size based on proto. -func Size(proto string, size int) int { +func Size(proto string, size uint16) uint16 { if proto == "tcp" { return dns.MaxMsgSize } diff --git a/request/request.go b/request/request.go index 6f1a1de0e..76bb6a787 100644 --- a/request/request.go +++ b/request/request.go @@ -18,15 +18,16 @@ type Request struct { // Optional lowercased zone of this query. Zone string - // Cache size after first call to Size or Do. - size int - do *bool // nil: nothing, otherwise *do value + // Cache size after first call to Size or Do. If size is zero nothing has been cached yet. + // Both Size and Do set these values (and cache them). + size uint16 // UDP buffer size, or 64K in case of TCP. + do bool // DNSSEC OK value // Caches + family int8 // transport's family. name string // lowercase qname. ip string // client's ip. port string // client's port. - family int // transport's family. localPort string // server's port. localIP string // server's ip. } @@ -127,7 +128,7 @@ func Proto(w dns.ResponseWriter) string { // Family returns the family of the transport, 1 for IPv4 and 2 for IPv6. func (r *Request) Family() int { if r.family != 0 { - return r.family + return int(r.family) } var a net.IP @@ -141,26 +142,20 @@ func (r *Request) Family() int { if a.To4() != nil { r.family = 1 - return r.family + return 1 } r.family = 2 - return r.family + return 2 } // Do returns if the request has the DO (DNSSEC OK) bit set. func (r *Request) Do() bool { - if r.do != nil { - return *r.do + if r.size != 0 { + return r.do } - r.do = new(bool) - - if o := r.Req.IsEdns0(); o != nil { - *r.do = o.Do() - return *r.do - } - *r.do = false - return false + r.Size() + return r.do } // Len returns the length in bytes in the request. @@ -170,21 +165,19 @@ func (r *Request) Len() int { return r.Req.Len() } // Or when the request was over TCP, we return the maximum allowed size of 64K. func (r *Request) Size() int { if r.size != 0 { - return r.size + return int(r.size) } - size := 0 + size := uint16(0) if o := r.Req.IsEdns0(); o != nil { - if r.do == nil { - r.do = new(bool) - } - *r.do = o.Do() - size = int(o.UDPSize()) + r.do = o.Do() + size = o.UDPSize() } + // normalize size size = edns.Size(r.Proto(), size) r.size = size - return size + return int(size) } // SizeAndDo adds an OPT record that the reflects the intent from request. diff --git a/request/request_test.go b/request/request_test.go index a62fc51bf..0a3b1f2d8 100644 --- a/request/request_test.go +++ b/request/request_test.go @@ -13,7 +13,7 @@ func TestRequestDo(t *testing.T) { st := testRequest() st.Do() - if st.do == nil { + if !st.do { t.Errorf("Expected st.do to be set") } }