289 lines
5.7 KiB
Go
289 lines
5.7 KiB
Go
package middleware
|
|
|
|
import (
|
|
"testing"
|
|
|
|
"github.com/miekg/coredns/middleware/test"
|
|
|
|
"github.com/miekg/dns"
|
|
)
|
|
|
|
func TestStateDo(t *testing.T) {
|
|
st := testState()
|
|
|
|
st.Do()
|
|
if st.do == 0 {
|
|
t.Fatalf("expected st.do to be set")
|
|
}
|
|
}
|
|
|
|
func TestStateRemote(t *testing.T) {
|
|
st := testState()
|
|
if st.IP() != "10.240.0.1" {
|
|
t.Fatalf("wrong IP from state")
|
|
}
|
|
p, err := st.Port()
|
|
if err != nil {
|
|
t.Fatalf("failed to get Port from state")
|
|
}
|
|
if p != "40212" {
|
|
t.Fatalf("wrong port from state")
|
|
}
|
|
}
|
|
|
|
func BenchmarkStateDo(b *testing.B) {
|
|
st := testState()
|
|
|
|
for i := 0; i < b.N; i++ {
|
|
st.Do()
|
|
}
|
|
}
|
|
|
|
func BenchmarkStateSize(b *testing.B) {
|
|
st := testState()
|
|
|
|
for i := 0; i < b.N; i++ {
|
|
st.Size()
|
|
}
|
|
}
|
|
|
|
func testState() State {
|
|
m := new(dns.Msg)
|
|
m.SetQuestion("example.com.", dns.TypeA)
|
|
m.SetEdns0(4097, true)
|
|
return State{W: &test.ResponseWriter{}, Req: m}
|
|
}
|
|
|
|
/*
|
|
func TestHeader(t *testing.T) {
|
|
state := getContextOrFail(t)
|
|
|
|
headerKey, headerVal := "Header1", "HeaderVal1"
|
|
state.Req.Header.Add(headerKey, headerVal)
|
|
|
|
actualHeaderVal := state.Header(headerKey)
|
|
if actualHeaderVal != headerVal {
|
|
t.Errorf("Expected header %s, found %s", headerVal, actualHeaderVal)
|
|
}
|
|
|
|
missingHeaderVal := state.Header("not-existing")
|
|
if missingHeaderVal != "" {
|
|
t.Errorf("Expected empty header value, found %s", missingHeaderVal)
|
|
}
|
|
}
|
|
|
|
func TestIP(t *testing.T) {
|
|
state := getContextOrFail(t)
|
|
|
|
tests := []struct {
|
|
inputRemoteAddr string
|
|
expectedIP string
|
|
}{
|
|
// Test 0 - ipv4 with port
|
|
{"1.1.1.1:1111", "1.1.1.1"},
|
|
// Test 1 - ipv4 without port
|
|
{"1.1.1.1", "1.1.1.1"},
|
|
// Test 2 - ipv6 with port
|
|
{"[::1]:11", "::1"},
|
|
// Test 3 - ipv6 with brackets but without port
|
|
{"[2001:db8:a0b:12f0::1]", "[2001:db8:a0b:12f0::1]"},
|
|
// Test 4 - ipv6 with zone and port
|
|
{`[fe80:1::3%eth0]:44`, `fe80:1::3%eth0`},
|
|
}
|
|
|
|
for i, test := range tests {
|
|
testPrefix := getTestPrefix(i)
|
|
|
|
state.Req.RemoteAddr = test.inputRemoteAddr
|
|
actualIP := state.IP()
|
|
|
|
if actualIP != test.expectedIP {
|
|
t.Errorf(testPrefix+"Expected IP %s, found %s", test.expectedIP, actualIP)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestURL(t *testing.T) {
|
|
state := getContextOrFail(t)
|
|
|
|
inputURL := "http://localhost"
|
|
state.Req.RequestURI = inputURL
|
|
|
|
if inputURL != state.URI() {
|
|
t.Errorf("Expected url %s, found %s", inputURL, state.URI())
|
|
}
|
|
}
|
|
|
|
func TestHost(t *testing.T) {
|
|
tests := []struct {
|
|
input string
|
|
expectedHost string
|
|
shouldErr bool
|
|
}{
|
|
{
|
|
input: "localhost:123",
|
|
expectedHost: "localhost",
|
|
shouldErr: false,
|
|
},
|
|
{
|
|
input: "localhost",
|
|
expectedHost: "localhost",
|
|
shouldErr: false,
|
|
},
|
|
{
|
|
input: "[::]",
|
|
expectedHost: "",
|
|
shouldErr: true,
|
|
},
|
|
}
|
|
|
|
for _, test := range tests {
|
|
testHostOrPort(t, true, test.input, test.expectedHost, test.shouldErr)
|
|
}
|
|
}
|
|
|
|
func TestPort(t *testing.T) {
|
|
tests := []struct {
|
|
input string
|
|
expectedPort string
|
|
shouldErr bool
|
|
}{
|
|
{
|
|
input: "localhost:123",
|
|
expectedPort: "123",
|
|
shouldErr: false,
|
|
},
|
|
{
|
|
input: "localhost",
|
|
expectedPort: "80", // assuming 80 is the default port
|
|
shouldErr: false,
|
|
},
|
|
{
|
|
input: ":8080",
|
|
expectedPort: "8080",
|
|
shouldErr: false,
|
|
},
|
|
{
|
|
input: "[::]",
|
|
expectedPort: "",
|
|
shouldErr: true,
|
|
},
|
|
}
|
|
|
|
for _, test := range tests {
|
|
testHostOrPort(t, false, test.input, test.expectedPort, test.shouldErr)
|
|
}
|
|
}
|
|
|
|
func testHostOrPort(t *testing.T, isTestingHost bool, input, expectedResult string, shouldErr bool) {
|
|
state := getContextOrFail(t)
|
|
|
|
state.Req.Host = input
|
|
var actualResult, testedObject string
|
|
var err error
|
|
|
|
if isTestingHost {
|
|
actualResult, err = state.Host()
|
|
testedObject = "host"
|
|
} else {
|
|
actualResult, err = state.Port()
|
|
testedObject = "port"
|
|
}
|
|
|
|
if shouldErr && err == nil {
|
|
t.Errorf("Expected error, found nil!")
|
|
return
|
|
}
|
|
|
|
if !shouldErr && err != nil {
|
|
t.Errorf("Expected no error, found %s", err)
|
|
return
|
|
}
|
|
|
|
if actualResult != expectedResult {
|
|
t.Errorf("Expected %s %s, found %s", testedObject, expectedResult, actualResult)
|
|
}
|
|
}
|
|
|
|
func TestPathMatches(t *testing.T) {
|
|
state := getContextOrFail(t)
|
|
|
|
tests := []struct {
|
|
urlStr string
|
|
pattern string
|
|
shouldMatch bool
|
|
}{
|
|
// Test 0
|
|
{
|
|
urlStr: "http://localhost/",
|
|
pattern: "",
|
|
shouldMatch: true,
|
|
},
|
|
// Test 1
|
|
{
|
|
urlStr: "http://localhost",
|
|
pattern: "",
|
|
shouldMatch: true,
|
|
},
|
|
// Test 2
|
|
{
|
|
urlStr: "http://localhost/",
|
|
pattern: "/",
|
|
shouldMatch: true,
|
|
},
|
|
// Test 3
|
|
{
|
|
urlStr: "http://localhost/?param=val",
|
|
pattern: "/",
|
|
shouldMatch: true,
|
|
},
|
|
// Test 4
|
|
{
|
|
urlStr: "http://localhost/dir1/dir2",
|
|
pattern: "/dir2",
|
|
shouldMatch: false,
|
|
},
|
|
// Test 5
|
|
{
|
|
urlStr: "http://localhost/dir1/dir2",
|
|
pattern: "/dir1",
|
|
shouldMatch: true,
|
|
},
|
|
// Test 6
|
|
{
|
|
urlStr: "http://localhost:444/dir1/dir2",
|
|
pattern: "/dir1",
|
|
shouldMatch: true,
|
|
},
|
|
}
|
|
|
|
for i, test := range tests {
|
|
testPrefix := getTestPrefix(i)
|
|
var err error
|
|
state.Req.URL, err = url.Parse(test.urlStr)
|
|
if err != nil {
|
|
t.Fatalf("Failed to prepare test URL from string %s! Error was: %s", test.urlStr, err)
|
|
}
|
|
|
|
matches := state.PathMatches(test.pattern)
|
|
if matches != test.shouldMatch {
|
|
t.Errorf(testPrefix+"Expected and actual result differ: expected to match [%t], actual matches [%t]", test.shouldMatch, matches)
|
|
}
|
|
}
|
|
}
|
|
|
|
func initTestContext() (Context, error) {
|
|
body := bytes.NewBufferString("request body")
|
|
request, err := http.NewRequest("GET", "https://localhost", body)
|
|
if err != nil {
|
|
return Context{}, err
|
|
}
|
|
|
|
return Context{Root: http.Dir(os.TempDir()), Req: request}, nil
|
|
}
|
|
|
|
|
|
func getTestPrefix(testN int) string {
|
|
return fmt.Sprintf("Test [%d]: ", testN)
|
|
}
|
|
*/
|