package middleware

import (
	"testing"

	"github.com/miekg/coredns/middleware/test"

	"github.com/miekg/dns"
)

// TestStateDo checks that calling Do on a State leaves a non-zero value in
// its unexported do field.
func TestStateDo(t *testing.T) {
	state := testState()
	state.Do()

	// A non-zero do field means Do recorded a result; nothing more to check.
	if state.do != 0 {
		return
	}
	t.Fatal("expected st.do to be set")
}

// TestStateRemote verifies the remote IP and port reported by a State built
// with testState.
//
// NOTE(review): the expected values are presumably the fixed remote address
// exposed by test.ResponseWriter — confirm against that helper if it changes.
func TestStateRemote(t *testing.T) {
	const (
		wantIP   = "10.240.0.1"
		wantPort = "40212"
	)

	st := testState()

	if ip := st.IP(); ip != wantIP {
		t.Fatalf("wrong IP from state: expected %q, got %q", wantIP, ip)
	}

	p, err := st.Port()
	if err != nil {
		// Surface the underlying error so a failure is diagnosable.
		t.Fatalf("failed to get Port from state: %v", err)
	}
	if p != wantPort {
		t.Fatalf("wrong port from state: expected %q, got %q", wantPort, p)
	}
}

// BenchmarkStateDo measures repeated calls to Do on a single State built by
// testState.
func BenchmarkStateDo(b *testing.B) {
	state := testState()

	// Count down from b.N; the body runs exactly b.N times.
	for remaining := b.N; remaining > 0; remaining-- {
		state.Do()
	}
}

// BenchmarkStateSize measures repeated calls to Size on a single State built
// by testState.
func BenchmarkStateSize(b *testing.B) {
	state := testState()

	// Count down from b.N; the body runs exactly b.N times.
	for remaining := b.N; remaining > 0; remaining-- {
		state.Size()
	}
}

// testState returns the State fixture shared by the tests and benchmarks in
// this file: a request asking for the A record of "example.com." with
// SetEdns0(4097, true) applied, paired with a test.ResponseWriter.
//
// NOTE(review): in miekg/dns the SetEdns0 arguments are the advertised UDP
// size and the DO bit — confirm against the vendored version.
func testState() State {
	req := &dns.Msg{}
	req.SetQuestion("example.com.", dns.TypeA)
	req.SetEdns0(4097, true)

	writer := &test.ResponseWriter{}
	return State{Req: req, W: writer}
}

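// NOTE: the tests below are disabled. They rely on HTTP-style request fields
// (Header, RemoteAddr, RequestURI, Host, URL), on a getContextOrFail helper,
// and on packages (bytes, fmt, net/http, net/url, os) that the active code in
// this file neither defines nor imports.
/*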
func TestHeader(t *testing.T) {
	state := getContextOrFail(t)

	headerKey, headerVal := "Header1", "HeaderVal1"
	state.Req.Header.Add(headerKey, headerVal)

	actualHeaderVal := state.Header(headerKey)
	if actualHeaderVal != headerVal {
		t.Errorf("Expected header %s, found %s", headerVal, actualHeaderVal)
	}

	missingHeaderVal := state.Header("not-existing")
	if missingHeaderVal != "" {
		t.Errorf("Expected empty header value, found %s", missingHeaderVal)
	}
}

func TestIP(t *testing.T) {
	state := getContextOrFail(t)

	tests := []struct {
		inputRemoteAddr string
		expectedIP      string
	}{
		// Test 0 - ipv4 with port
		{"1.1.1.1:1111", "1.1.1.1"},
		// Test 1 - ipv4 without port
		{"1.1.1.1", "1.1.1.1"},
		// Test 2 - ipv6 with port
		{"[::1]:11", "::1"},
		// Test 3 - ipv6 without port and brackets
		{"[2001:db8:a0b:12f0::1]", "[2001:db8:a0b:12f0::1]"},
		// Test 4 - ipv6 with zone and port
		{`[fe80:1::3%eth0]:44`, `fe80:1::3%eth0`},
	}

	for i, test := range tests {
		testPrefix := getTestPrefix(i)

		state.Req.RemoteAddr = test.inputRemoteAddr
		actualIP := state.IP()

		if actualIP != test.expectedIP {
			t.Errorf(testPrefix+"Expected IP %s, found %s", test.expectedIP, actualIP)
		}
	}
}

func TestURL(t *testing.T) {
	state := getContextOrFail(t)

	inputURL := "http://localhost"
	state.Req.RequestURI = inputURL

	if inputURL != state.URI() {
		t.Errorf("Expected url %s, found %s", inputURL, state.URI())
	}
}

func TestHost(t *testing.T) {
	tests := []struct {
		input        string
		expectedHost string
		shouldErr    bool
	}{
		{
			input:        "localhost:123",
			expectedHost: "localhost",
			shouldErr:    false,
		},
		{
			input:        "localhost",
			expectedHost: "localhost",
			shouldErr:    false,
		},
		{
			input:        "[::]",
			expectedHost: "",
			shouldErr:    true,
		},
	}

	for _, test := range tests {
		testHostOrPort(t, true, test.input, test.expectedHost, test.shouldErr)
	}
}

func TestPort(t *testing.T) {
	tests := []struct {
		input        string
		expectedPort string
		shouldErr    bool
	}{
		{
			input:        "localhost:123",
			expectedPort: "123",
			shouldErr:    false,
		},
		{
			input:        "localhost",
			expectedPort: "80", // assuming 80 is the default port
			shouldErr:    false,
		},
		{
			input:        ":8080",
			expectedPort: "8080",
			shouldErr:    false,
		},
		{
			input:        "[::]",
			expectedPort: "",
			shouldErr:    true,
		},
	}

	for _, test := range tests {
		testHostOrPort(t, false, test.input, test.expectedPort, test.shouldErr)
	}
}

func testHostOrPort(t *testing.T, isTestingHost bool, input, expectedResult string, shouldErr bool) {
	state := getContextOrFail(t)

	state.Req.Host = input
	var actualResult, testedObject string
	var err error

	if isTestingHost {
		actualResult, err = state.Host()
		testedObject = "host"
	} else {
		actualResult, err = state.Port()
		testedObject = "port"
	}

	if shouldErr && err == nil {
		t.Errorf("Expected error, found nil!")
		return
	}

	if !shouldErr && err != nil {
		t.Errorf("Expected no error, found %s", err)
		return
	}

	if actualResult != expectedResult {
		t.Errorf("Expected %s %s, found %s", testedObject, expectedResult, actualResult)
	}
}

func TestPathMatches(t *testing.T) {
	state := getContextOrFail(t)

	tests := []struct {
		urlStr      string
		pattern     string
		shouldMatch bool
	}{
		// Test 0
		{
			urlStr:      "http://localhost/",
			pattern:     "",
			shouldMatch: true,
		},
		// Test 1
		{
			urlStr:      "http://localhost",
			pattern:     "",
			shouldMatch: true,
		},
		// Test 2
		{
			urlStr:      "http://localhost/",
			pattern:     "/",
			shouldMatch: true,
		},
		// Test 3
		{
			urlStr:      "http://localhost/?param=val",
			pattern:     "/",
			shouldMatch: true,
		},
		// Test 4
		{
			urlStr:      "http://localhost/dir1/dir2",
			pattern:     "/dir2",
			shouldMatch: false,
		},
		// Test 5
		{
			urlStr:      "http://localhost/dir1/dir2",
			pattern:     "/dir1",
			shouldMatch: true,
		},
		// Test 6
		{
			urlStr:      "http://localhost:444/dir1/dir2",
			pattern:     "/dir1",
			shouldMatch: true,
		},
	}

	for i, test := range tests {
		testPrefix := getTestPrefix(i)
		var err error
		state.Req.URL, err = url.Parse(test.urlStr)
		if err != nil {
			t.Fatalf("Failed to prepare test URL from string %s! Error was: %s", test.urlStr, err)
		}

		matches := state.PathMatches(test.pattern)
		if matches != test.shouldMatch {
			t.Errorf(testPrefix+"Expected and actual result differ: expected to match [%t], actual matches [%t]", test.shouldMatch, matches)
		}
	}
}

func initTestContext() (Context, error) {
	body := bytes.NewBufferString("request body")
	request, err := http.NewRequest("GET", "https://localhost", body)
	if err != nil {
		return Context{}, err
	}

	return Context{Root: http.Dir(os.TempDir()), Req: request}, nil
}


func getTestPrefix(testN int) string {
	return fmt.Sprintf("Test [%d]: ", testN)
}
*/