diff --git a/request/request.go b/request/request.go index 36ec85731..06f840f89 100644 --- a/request/request.go +++ b/request/request.go @@ -32,7 +32,7 @@ type Request struct { port string // client's port. family int // transport's family. localPort string // server's port. - // TODO(miek): localIP once that is merged. + localIP string // server's ip. } // NewWithQuestion returns a new request based on the old, but with a new question @@ -61,11 +61,18 @@ func (r *Request) IP() string { // LocalIP gets the (local) IP address of server handling the request. func (r *Request) LocalIP() string { + if r.localIP != "" { + return r.localIP + } + ip, _, err := net.SplitHostPort(r.W.LocalAddr().String()) if err != nil { - return r.W.LocalAddr().String() + r.localIP = r.W.LocalAddr().String() + return r.localIP } - return ip + + r.localIP = ip + return r.localIP } // Port gets the (remote) port of the client making the request. @@ -423,6 +430,7 @@ func (r *Request) ErrorMessage(rcode int) *dns.Msg { func (r *Request) Clear() { r.name = "" r.ip = "" + r.localIP = "" r.port = "" r.localPort = "" r.family = 0 diff --git a/request/request_test.go b/request/request_test.go index f99e5cd87..c58612605 100644 --- a/request/request_test.go +++ b/request/request_test.go @@ -14,21 +14,21 @@ func TestRequestDo(t *testing.T) { st.Do() if st.do == nil { - t.Fatalf("Expected st.do to be set") + t.Errorf("Expected st.do to be set") } } func TestRequestRemote(t *testing.T) { st := testRequest() if st.IP() != "10.240.0.1" { - t.Fatalf("Wrong IP from request") + t.Errorf("Wrong IP from request") } p := st.Port() if p == "" { - t.Fatalf("Failed to get Port from request") + t.Errorf("Failed to get Port from request") } if p != "40212" { - t.Fatalf("Wrong port from request") + t.Errorf("Wrong port from request") } } @@ -202,3 +202,22 @@ func testRequest() Request { m.SetEdns0(4097, true) return Request{W: &test.ResponseWriter{}, Req: m} } + +func TestRequestClear(t *testing.T) { + st := testRequest() + if st.IP() != "10.240.0.1" { + t.Errorf("Wrong IP from request") + } + p := st.Port() + if p == "" { + t.Errorf("Failed to get Port from request") + } + st.Clear() + if st.ip != "" { + t.Errorf("Expected st.ip to be cleared after Clear") + } + + if st.port != "" { + t.Errorf("Expected st.port to be cleared after Clear") + } +}