package context

import (
	"net/http"
	"net/http/httptest"
	"net/http/httputil"
	"net/url"
	"reflect"
	"testing"
	"time"
)

// TestWithRequest builds a context from a hand-assembled request and checks
// that every "http.request*" key resolves to a non-nil value. Where an
// expected value is supplied, the resolved value must equal it; the request
// id and start time are additionally checked for type and plausibility.
func TestWithRequest(t *testing.T) {
	began := time.Now()

	req := http.Request{
		Method:     "GET",
		Host:       "example.com",
		RequestURI: "/test-test",
		Header:     make(http.Header),
	}
	req.Header.Set("Referer", "foo.com/referer")
	req.Header.Set("User-Agent", "test/0.1")

	ctx := WithRequest(Background(), &req)

	// A nil want means "only require that some value is present".
	type expectation struct {
		key  string
		want interface{}
	}

	cases := []expectation{
		{key: "http.request", want: &req},
		{key: "http.request.id"},
		{key: "http.request.method", want: req.Method},
		{key: "http.request.host", want: req.Host},
		{key: "http.request.uri", want: req.RequestURI},
		{key: "http.request.referer", want: req.Referer()},
		{key: "http.request.useragent", want: req.UserAgent()},
		{key: "http.request.remoteaddr", want: req.RemoteAddr},
		{key: "http.request.startedat"},
	}

	for _, tc := range cases {
		got := ctx.Value(tc.key)
		if got == nil {
			t.Fatalf("value not found for %q", tc.key)
		}

		if tc.want != nil && got != tc.want {
			t.Fatalf("%s: %v != %v", tc.key, got, tc.want)
		}

		// Only the generated values need further, key-specific validation.
		if tc.key == "http.request.id" {
			if _, isString := got.(string); !isString {
				t.Fatalf("request id not a string: %v", got)
			}
			continue
		}

		if tc.key != "http.request.startedat" {
			continue
		}

		startedAt, isTime := got.(time.Time)
		if !isTime {
			t.Fatalf("value not a time: %v", got)
		}

		// The timestamp must fall between the start of the test and now.
		now := time.Now()
		if startedAt.After(now) {
			t.Fatalf("time generated too late: %v > %v", startedAt, now)
		}

		if startedAt.Before(began) {
			t.Fatalf("time generated too early: %v < %v", startedAt, began)
		}
	}
}

// testResponseWriter is an in-memory http.ResponseWriter (and http.Flusher)
// that discards payloads and only records what was done to it.
type testResponseWriter struct {
	flushed bool        // set once Flush has been called
	status  int         // last status recorded; zero until a write or WriteHeader
	written int64       // running total of bytes passed to Write
	header  http.Header // lazily allocated by Header
}

// Header returns the writer's header map, allocating it on first use.
func (w *testResponseWriter) Header() http.Header {
	if w.header != nil {
		return w.header
	}

	w.header = make(http.Header)
	return w.header
}

// Write counts the bytes in p without storing them and always succeeds. If no
// status has been recorded yet, it records http.StatusOK.
func (w *testResponseWriter) Write(p []byte) (int, error) {
	if w.status == 0 {
		w.status = http.StatusOK
	}

	size := len(p)
	w.written += int64(size)
	return size, nil
}

// WriteHeader records status, overwriting any previously recorded value.
func (w *testResponseWriter) WriteHeader(status int) {
	w.status = status
}

// Flush marks the writer as having been flushed.
func (w *testResponseWriter) Flush() {
	w.flushed = true
}

// TestWithResponseWriter wraps a testResponseWriter with WithResponseWriter
// and checks that the wrapper is retrievable from the context, that status
// and byte counts are reported through the context, and that Flush reaches
// the underlying writer.
func TestWithResponseWriter(t *testing.T) {
	var underlying testResponseWriter
	ctx, wrapped := WithResponseWriter(Background(), &underlying)

	// status reads the response status as currently exposed by the context.
	status := func() interface{} {
		return ctx.Value("http.response.status")
	}

	if ctx.Value("http.response") != wrapped {
		t.Fatalf("response not available in context: %v != %v", ctx.Value("http.response"), wrapped)
	}

	fetched, err := GetResponseWriter(ctx)
	if err != nil {
		t.Fatalf("error getting response writer: %v", err)
	}

	if fetched != wrapped {
		t.Fatalf("unexpected response writer returned: %#v != %#v", fetched, wrapped)
	}

	// Nothing has been written yet, so the status is expected to be zero.
	if status() != 0 {
		t.Fatalf("response status should always be a number and should be zero here: %v != 0", status())
	}

	count, err := wrapped.Write(make([]byte, 1024))
	if err != nil {
		t.Fatalf("unexpected error writing: %v", err)
	}
	if count != 1024 {
		t.Fatalf("unexpected number of bytes written: %v != %v", count, 1024)
	}

	if status() != http.StatusOK {
		t.Fatalf("unexpected response status in context: %v != %v", status(), http.StatusOK)
	}

	if ctx.Value("http.response.written") != int64(1024) {
		t.Fatalf("unexpected number reported bytes written: %v != %v", ctx.Value("http.response.written"), 1024)
	}

	// A flush on the wrapper must reach the underlying writer.
	flusher := wrapped.(http.Flusher)
	flusher.Flush()

	if !underlying.flushed {
		t.Fatalf("response writer not flushed")
	}

	// Setting a second status is only possible in this contrived test; the
	// context should nevertheless report the latest value.
	wrapped.WriteHeader(http.StatusBadRequest)

	if status() != http.StatusBadRequest {
		t.Fatalf("unexpected response status in context: %v != %v", status(), http.StatusBadRequest)
	}
}

// TestWithVars replaces the package-level getVarsFromRequest hook with a stub
// returning a fixed map, then checks that WithVars exposes that map under the
// "vars" key and each entry under "vars.<name>".
func TestWithVars(t *testing.T) {
	var req http.Request
	vars := map[string]string{
		"foo": "asdf",
		"bar": "qwer",
	}

	// getVarsFromRequest is shared package state: put the previous value back
	// when the test finishes so the stub does not leak into other tests.
	previous := getVarsFromRequest
	defer func() { getVarsFromRequest = previous }()

	getVarsFromRequest = func(r *http.Request) map[string]string {
		if r != &req {
			// Report both sides as pointers, matching the comparison above.
			t.Fatalf("unexpected request: %v != %v", r, &req)
		}

		return vars
	}

	ctx := WithVars(Background(), &req)
	for _, testcase := range []struct {
		key      string
		expected interface{}
	}{
		{
			key:      "vars",
			expected: vars,
		},
		{
			key:      "vars.foo",
			expected: "asdf",
		},
		{
			key:      "vars.bar",
			expected: "qwer",
		},
	} {
		v := ctx.Value(testcase.key)

		if !reflect.DeepEqual(v, testcase.expected) {
			t.Fatalf("%q: %v != %v", testcase.key, v, testcase.expected)
		}
	}
}

// TestRemoteAddr checks RemoteAddr against requests that carry forwarding
// headers.
//
// SingleHostReverseProxy will insert an X-Forwarded-For header, and can be
// used to test RemoteAddr(). A fake RemoteAddr cannot be set on the HTTP
// request - it is overwritten at the transport layer to 127.0.0.1:<port>.
// However, as the X-Forwarded-For header just contains the IP address, it is
// different enough for testing.
func TestRemoteAddr(t *testing.T) {
	// expectedRemote is assigned before each request and read by the backend
	// handler while it serves that request.
	var expectedRemote string
	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		defer r.Body.Close()

		// The raw transport address includes a port, so it must never equal
		// the bare IP that RemoteAddr is expected to return.
		if r.RemoteAddr == expectedRemote {
			t.Errorf("Unexpected matching remote addresses")
		}

		actualRemote := RemoteAddr(r)
		if expectedRemote != actualRemote {
			t.Errorf("Mismatching remote hosts: %v != %v", expectedRemote, actualRemote)
		}

		w.WriteHeader(200)
	}))

	defer backend.Close()
	backendURL, err := url.Parse(backend.URL)
	if err != nil {
		t.Fatal(err)
	}

	proxy := httputil.NewSingleHostReverseProxy(backendURL)
	frontend := httptest.NewServer(proxy)
	defer frontend.Close()

	// X-Forwarded-For set by proxy
	expectedRemote = "127.0.0.1"
	proxyReq, err := http.NewRequest("GET", frontend.URL, nil)
	if err != nil {
		t.Fatal(err)
	}

	resp, err := http.DefaultClient.Do(proxyReq)
	if err != nil {
		t.Fatal(err)
	}
	// Close every response body so the client's connection is not leaked.
	resp.Body.Close()

	// RemoteAddr in X-Real-Ip
	getReq, err := http.NewRequest("GET", backend.URL, nil)
	if err != nil {
		t.Fatal(err)
	}

	expectedRemote = "1.2.3.4"
	getReq.Header["X-Real-ip"] = []string{expectedRemote}
	resp, err = http.DefaultClient.Do(getReq)
	if err != nil {
		t.Fatal(err)
	}
	resp.Body.Close()

	// Valid X-Real-Ip and invalid X-Forwarded-For
	getReq.Header["X-forwarded-for"] = []string{"1.2.3"}
	resp, err = http.DefaultClient.Do(getReq)
	if err != nil {
		t.Fatal(err)
	}
	resp.Body.Close()
}