e21a425f88
Fix lint errors Add more test
274 lines
6 KiB
Go
274 lines
6 KiB
Go
package context
|
|
|
|
import (
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/http/httputil"
|
|
"net/url"
|
|
"reflect"
|
|
"testing"
|
|
"time"
|
|
|
|
"golang.org/x/net/context"
|
|
)
|
|
|
|
func TestWithRequest(t *testing.T) {
|
|
var req http.Request
|
|
|
|
start := time.Now()
|
|
req.Method = "GET"
|
|
req.Host = "example.com"
|
|
req.RequestURI = "/test-test"
|
|
req.Header = make(http.Header)
|
|
req.Header.Set("Referer", "foo.com/referer")
|
|
req.Header.Set("User-Agent", "test/0.1")
|
|
|
|
ctx := WithRequest(context.Background(), &req)
|
|
for _, testcase := range []struct {
|
|
key string
|
|
expected interface{}
|
|
}{
|
|
{
|
|
key: "http.request",
|
|
expected: &req,
|
|
},
|
|
{
|
|
key: "http.request.id",
|
|
},
|
|
{
|
|
key: "http.request.method",
|
|
expected: req.Method,
|
|
},
|
|
{
|
|
key: "http.request.host",
|
|
expected: req.Host,
|
|
},
|
|
{
|
|
key: "http.request.uri",
|
|
expected: req.RequestURI,
|
|
},
|
|
{
|
|
key: "http.request.referer",
|
|
expected: req.Referer(),
|
|
},
|
|
{
|
|
key: "http.request.useragent",
|
|
expected: req.UserAgent(),
|
|
},
|
|
{
|
|
key: "http.request.remoteaddr",
|
|
expected: req.RemoteAddr,
|
|
},
|
|
{
|
|
key: "http.request.startedat",
|
|
},
|
|
} {
|
|
v := ctx.Value(testcase.key)
|
|
|
|
if v == nil {
|
|
t.Fatalf("value not found for %q", testcase.key)
|
|
}
|
|
|
|
if testcase.expected != nil && v != testcase.expected {
|
|
t.Fatalf("%s: %v != %v", testcase.key, v, testcase.expected)
|
|
}
|
|
|
|
// Key specific checks!
|
|
switch testcase.key {
|
|
case "http.request.id":
|
|
if _, ok := v.(string); !ok {
|
|
t.Fatalf("request id not a string: %v", v)
|
|
}
|
|
case "http.request.startedat":
|
|
vt, ok := v.(time.Time)
|
|
if !ok {
|
|
t.Fatalf("value not a time: %v", v)
|
|
}
|
|
|
|
now := time.Now()
|
|
if vt.After(now) {
|
|
t.Fatalf("time generated too late: %v > %v", vt, now)
|
|
}
|
|
|
|
if vt.Before(start) {
|
|
t.Fatalf("time generated too early: %v < %v", vt, start)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// testResponseWriter is a minimal in-memory response writer for tests. It
// records the status code, the number of bytes written and whether Flush was
// called; the written payload itself is discarded.
type testResponseWriter struct {
	flushed bool        // set by Flush
	status  int         // last recorded status; 0 until Write or WriteHeader runs
	written int64       // running total of bytes passed to Write
	header  http.Header // allocated lazily by Header
}

// Compile-time checks: the tests pass this type where an http.ResponseWriter
// is required and expect Flush to be reachable through http.Flusher.
var (
	_ http.ResponseWriter = (*testResponseWriter)(nil)
	_ http.Flusher        = (*testResponseWriter)(nil)
)

// Header returns the writer's header map, creating it on first use.
func (trw *testResponseWriter) Header() http.Header {
	if trw.header == nil {
		trw.header = make(http.Header)
	}

	return trw.header
}

// Write counts len(p) bytes as written and reports them all as consumed. If no
// status has been recorded yet, it is set to http.StatusOK first. The returned
// error is always nil.
func (trw *testResponseWriter) Write(p []byte) (int, error) {
	if trw.status == 0 {
		trw.status = http.StatusOK
	}

	trw.written += int64(len(p))
	return len(p), nil
}

// WriteHeader records status, overwriting any previously recorded value.
func (trw *testResponseWriter) WriteHeader(status int) {
	trw.status = status
}

// Flush marks the writer as flushed.
func (trw *testResponseWriter) Flush() {
	trw.flushed = true
}
|
|
|
|
func TestWithResponseWriter(t *testing.T) {
|
|
trw := testResponseWriter{}
|
|
ctx, rw := WithResponseWriter(context.Background(), &trw)
|
|
|
|
if ctx.Value("http.response") != &trw {
|
|
t.Fatalf("response not available in context: %v != %v", ctx.Value("http.response"), &trw)
|
|
}
|
|
|
|
if n, err := rw.Write(make([]byte, 1024)); err != nil {
|
|
t.Fatalf("unexpected error writing: %v", err)
|
|
} else if n != 1024 {
|
|
t.Fatalf("unexpected number of bytes written: %v != %v", n, 1024)
|
|
}
|
|
|
|
if ctx.Value("http.response.status") != http.StatusOK {
|
|
t.Fatalf("unexpected response status in context: %v != %v", ctx.Value("http.response.status"), http.StatusOK)
|
|
}
|
|
|
|
if ctx.Value("http.response.written") != int64(1024) {
|
|
t.Fatalf("unexpected number reported bytes written: %v != %v", ctx.Value("http.response.written"), 1024)
|
|
}
|
|
|
|
// Make sure flush propagates
|
|
rw.(http.Flusher).Flush()
|
|
|
|
if !trw.flushed {
|
|
t.Fatalf("response writer not flushed")
|
|
}
|
|
|
|
// Write another status and make sure context is correct. This normally
|
|
// wouldn't work except for in this contrived testcase.
|
|
rw.WriteHeader(http.StatusBadRequest)
|
|
|
|
if ctx.Value("http.response.status") != http.StatusBadRequest {
|
|
t.Fatalf("unexpected response status in context: %v != %v", ctx.Value("http.response.status"), http.StatusBadRequest)
|
|
}
|
|
}
|
|
|
|
func TestWithVars(t *testing.T) {
|
|
var req http.Request
|
|
vars := map[string]string{
|
|
"foo": "asdf",
|
|
"bar": "qwer",
|
|
}
|
|
|
|
getVarsFromRequest = func(r *http.Request) map[string]string {
|
|
if r != &req {
|
|
t.Fatalf("unexpected request: %v != %v", r, req)
|
|
}
|
|
|
|
return vars
|
|
}
|
|
|
|
ctx := WithVars(context.Background(), &req)
|
|
for _, testcase := range []struct {
|
|
key string
|
|
expected interface{}
|
|
}{
|
|
{
|
|
key: "vars",
|
|
expected: vars,
|
|
},
|
|
{
|
|
key: "vars.foo",
|
|
expected: "asdf",
|
|
},
|
|
{
|
|
key: "vars.bar",
|
|
expected: "qwer",
|
|
},
|
|
} {
|
|
v := ctx.Value(testcase.key)
|
|
|
|
if !reflect.DeepEqual(v, testcase.expected) {
|
|
t.Fatalf("%q: %v != %v", testcase.key, v, testcase.expected)
|
|
}
|
|
}
|
|
}
|
|
|
|
// SingleHostReverseProxy will insert an X-Forwarded-For header, and can be used to test
|
|
// RemoteAddr(). A fake RemoteAddr cannot be set on the HTTP request - it is overwritten
|
|
// at the transport layer to 127.0.0.1:<port> . However, as the X-Forwarded-For header
|
|
// just contains the IP address, it is different enough for testing.
|
|
func TestRemoteAddr(t *testing.T) {
|
|
var expectedRemote string
|
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
defer r.Body.Close()
|
|
|
|
if r.RemoteAddr == expectedRemote {
|
|
t.Errorf("Unexpected matching remote addresses")
|
|
}
|
|
|
|
actualRemote := RemoteAddr(r)
|
|
if expectedRemote != actualRemote {
|
|
t.Errorf("Mismatching remote hosts: %v != %v", expectedRemote, actualRemote)
|
|
}
|
|
|
|
w.WriteHeader(200)
|
|
}))
|
|
|
|
defer backend.Close()
|
|
backendURL, err := url.Parse(backend.URL)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
proxy := httputil.NewSingleHostReverseProxy(backendURL)
|
|
frontend := httptest.NewServer(proxy)
|
|
defer frontend.Close()
|
|
|
|
// X-Forwarded-For set by proxy
|
|
expectedRemote = "127.0.0.1"
|
|
proxyReq, err := http.NewRequest("GET", frontend.URL, nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
_, err = http.DefaultClient.Do(proxyReq)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// RemoteAddr in X-Real-Ip
|
|
getReq, err := http.NewRequest("GET", backend.URL, nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
expectedRemote = "1.2.3.4"
|
|
getReq.Header["X-Real-ip"] = []string{expectedRemote}
|
|
_, err = http.DefaultClient.Do(getReq)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// Valid X-Real-Ip and invalid X-Forwarded-For
|
|
getReq.Header["X-forwarded-for"] = []string{"1.2.3"}
|
|
_, err = http.DefaultClient.Do(getReq)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|