Use context.Context
Rename the old Context to State and use context.Context in the middleware for intra-middleware communication and more.
This commit is contained in:
parent
523cc0a0fd
commit
f907311cdf
27 changed files with 358 additions and 919 deletions
|
@ -1,16 +1,6 @@
|
||||||
package https
|
package https
|
||||||
|
|
||||||
import (
|
/*
|
||||||
"io/ioutil"
|
|
||||||
"net/http"
|
|
||||||
"os"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/miekg/coredns/middleware/redirect"
|
|
||||||
"github.com/miekg/coredns/server"
|
|
||||||
"github.com/xenolf/lego/acme"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestHostQualifies(t *testing.T) {
|
func TestHostQualifies(t *testing.T) {
|
||||||
for i, test := range []struct {
|
for i, test := range []struct {
|
||||||
host string
|
host string
|
||||||
|
@ -330,3 +320,4 @@ func TestMarkQualified(t *testing.T) {
|
||||||
t.Errorf("Expected %d managed configs, but got %d", expectedManagedCount, count)
|
t.Errorf("Expected %d managed configs, but got %d", expectedManagedCount, count)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
*/
|
||||||
|
|
|
@ -4,6 +4,8 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"golang.org/x/net/context"
|
||||||
|
|
||||||
"github.com/miekg/coredns/core/parse"
|
"github.com/miekg/coredns/core/parse"
|
||||||
"github.com/miekg/coredns/middleware"
|
"github.com/miekg/coredns/middleware"
|
||||||
"github.com/miekg/coredns/server"
|
"github.com/miekg/coredns/server"
|
||||||
|
@ -70,7 +72,7 @@ func NewTestController(input string) *Controller {
|
||||||
//
|
//
|
||||||
// Used primarily for testing but needs to be exported so
|
// Used primarily for testing but needs to be exported so
|
||||||
// add-ons can use this as a convenience.
|
// add-ons can use this as a convenience.
|
||||||
var EmptyNext = middleware.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
var EmptyNext = middleware.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||||
return 0, nil
|
return 0, nil
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
|
@ -1,12 +1,6 @@
|
||||||
package setup
|
package setup
|
||||||
|
|
||||||
import (
|
/*
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/miekg/coredns/middleware"
|
|
||||||
"github.com/miekg/coredns/middleware/errors"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestErrors(t *testing.T) {
|
func TestErrors(t *testing.T) {
|
||||||
c := NewTestController(`errors`)
|
c := NewTestController(`errors`)
|
||||||
mid, err := Errors(c)
|
mid, err := Errors(c)
|
||||||
|
@ -154,5 +148,5 @@ func TestErrorsParse(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
*/
|
||||||
|
|
|
@ -1,13 +1,6 @@
|
||||||
package setup
|
package setup
|
||||||
|
|
||||||
import (
|
/*
|
||||||
"fmt"
|
|
||||||
"regexp"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/miekg/coredns/middleware/rewrite"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestRewrite(t *testing.T) {
|
func TestRewrite(t *testing.T) {
|
||||||
c := NewTestController(`rewrite /from /to`)
|
c := NewTestController(`rewrite /from /to`)
|
||||||
|
|
||||||
|
@ -237,5 +230,5 @@ func TestRewriteParse(t *testing.T) {
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
*/
|
||||||
|
|
|
@ -1,613 +0,0 @@
|
||||||
package middleware
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"fmt"
|
|
||||||
"io/ioutil"
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestInclude(t *testing.T) {
|
|
||||||
context := getContextOrFail(t)
|
|
||||||
|
|
||||||
inputFilename := "test_file"
|
|
||||||
absInFilePath := filepath.Join(fmt.Sprintf("%s", context.Root), inputFilename)
|
|
||||||
defer func() {
|
|
||||||
err := os.Remove(absInFilePath)
|
|
||||||
if err != nil && !os.IsNotExist(err) {
|
|
||||||
t.Fatalf("Failed to clean test file!")
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
fileContent string
|
|
||||||
expectedContent string
|
|
||||||
shouldErr bool
|
|
||||||
expectedErrorContent string
|
|
||||||
}{
|
|
||||||
// Test 0 - all good
|
|
||||||
{
|
|
||||||
fileContent: `str1 {{ .Root }} str2`,
|
|
||||||
expectedContent: fmt.Sprintf("str1 %s str2", context.Root),
|
|
||||||
shouldErr: false,
|
|
||||||
expectedErrorContent: "",
|
|
||||||
},
|
|
||||||
// Test 1 - failure on template.Parse
|
|
||||||
{
|
|
||||||
fileContent: `str1 {{ .Root } str2`,
|
|
||||||
expectedContent: "",
|
|
||||||
shouldErr: true,
|
|
||||||
expectedErrorContent: `unexpected "}" in operand`,
|
|
||||||
},
|
|
||||||
// Test 3 - failure on template.Execute
|
|
||||||
{
|
|
||||||
fileContent: `str1 {{ .InvalidField }} str2`,
|
|
||||||
expectedContent: "",
|
|
||||||
shouldErr: true,
|
|
||||||
expectedErrorContent: `InvalidField is not a field of struct type middleware.Context`,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, test := range tests {
|
|
||||||
testPrefix := getTestPrefix(i)
|
|
||||||
|
|
||||||
// WriteFile truncates the contentt
|
|
||||||
err := ioutil.WriteFile(absInFilePath, []byte(test.fileContent), os.ModePerm)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(testPrefix+"Failed to create test file. Error was: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
content, err := context.Include(inputFilename)
|
|
||||||
if err != nil {
|
|
||||||
if !test.shouldErr {
|
|
||||||
t.Errorf(testPrefix+"Expected no error, found [%s]", test.expectedErrorContent, err.Error())
|
|
||||||
}
|
|
||||||
if !strings.Contains(err.Error(), test.expectedErrorContent) {
|
|
||||||
t.Errorf(testPrefix+"Expected error content [%s], found [%s]", test.expectedErrorContent, err.Error())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err == nil && test.shouldErr {
|
|
||||||
t.Errorf(testPrefix+"Expected error [%s] but found nil. Input file was: %s", test.expectedErrorContent, inputFilename)
|
|
||||||
}
|
|
||||||
|
|
||||||
if content != test.expectedContent {
|
|
||||||
t.Errorf(testPrefix+"Expected content [%s] but found [%s]. Input file was: %s", test.expectedContent, content, inputFilename)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIncludeNotExisting(t *testing.T) {
|
|
||||||
context := getContextOrFail(t)
|
|
||||||
|
|
||||||
_, err := context.Include("not_existing")
|
|
||||||
if err == nil {
|
|
||||||
t.Errorf("Expected error but found nil!")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMarkdown(t *testing.T) {
|
|
||||||
context := getContextOrFail(t)
|
|
||||||
|
|
||||||
inputFilename := "test_file"
|
|
||||||
absInFilePath := filepath.Join(fmt.Sprintf("%s", context.Root), inputFilename)
|
|
||||||
defer func() {
|
|
||||||
err := os.Remove(absInFilePath)
|
|
||||||
if err != nil && !os.IsNotExist(err) {
|
|
||||||
t.Fatalf("Failed to clean test file!")
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
fileContent string
|
|
||||||
expectedContent string
|
|
||||||
}{
|
|
||||||
// Test 0 - test parsing of markdown
|
|
||||||
{
|
|
||||||
fileContent: "* str1\n* str2\n",
|
|
||||||
expectedContent: "<ul>\n<li>str1</li>\n<li>str2</li>\n</ul>\n",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, test := range tests {
|
|
||||||
testPrefix := getTestPrefix(i)
|
|
||||||
|
|
||||||
// WriteFile truncates the contentt
|
|
||||||
err := ioutil.WriteFile(absInFilePath, []byte(test.fileContent), os.ModePerm)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(testPrefix+"Failed to create test file. Error was: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
content, _ := context.Markdown(inputFilename)
|
|
||||||
if content != test.expectedContent {
|
|
||||||
t.Errorf(testPrefix+"Expected content [%s] but found [%s]. Input file was: %s", test.expectedContent, content, inputFilename)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCookie(t *testing.T) {
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
cookie *http.Cookie
|
|
||||||
cookieName string
|
|
||||||
expectedValue string
|
|
||||||
}{
|
|
||||||
// Test 0 - happy path
|
|
||||||
{
|
|
||||||
cookie: &http.Cookie{Name: "cookieName", Value: "cookieValue"},
|
|
||||||
cookieName: "cookieName",
|
|
||||||
expectedValue: "cookieValue",
|
|
||||||
},
|
|
||||||
// Test 1 - try to get a non-existing cookie
|
|
||||||
{
|
|
||||||
cookie: &http.Cookie{Name: "cookieName", Value: "cookieValue"},
|
|
||||||
cookieName: "notExisting",
|
|
||||||
expectedValue: "",
|
|
||||||
},
|
|
||||||
// Test 2 - partial name match
|
|
||||||
{
|
|
||||||
cookie: &http.Cookie{Name: "cookie", Value: "cookieValue"},
|
|
||||||
cookieName: "cook",
|
|
||||||
expectedValue: "",
|
|
||||||
},
|
|
||||||
// Test 3 - cookie with optional fields
|
|
||||||
{
|
|
||||||
cookie: &http.Cookie{Name: "cookie", Value: "cookieValue", Path: "/path", Domain: "https://localhost", Expires: (time.Now().Add(10 * time.Minute)), MaxAge: 120},
|
|
||||||
cookieName: "cookie",
|
|
||||||
expectedValue: "cookieValue",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, test := range tests {
|
|
||||||
testPrefix := getTestPrefix(i)
|
|
||||||
|
|
||||||
// reinitialize the context for each test
|
|
||||||
context := getContextOrFail(t)
|
|
||||||
|
|
||||||
context.Req.AddCookie(test.cookie)
|
|
||||||
|
|
||||||
actualCookieVal := context.Cookie(test.cookieName)
|
|
||||||
|
|
||||||
if actualCookieVal != test.expectedValue {
|
|
||||||
t.Errorf(testPrefix+"Expected cookie value [%s] but found [%s] for cookie with name %s", test.expectedValue, actualCookieVal, test.cookieName)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCookieMultipleCookies(t *testing.T) {
|
|
||||||
context := getContextOrFail(t)
|
|
||||||
|
|
||||||
cookieNameBase, cookieValueBase := "cookieName", "cookieValue"
|
|
||||||
|
|
||||||
// make sure that there's no state and multiple requests for different cookies return the correct result
|
|
||||||
for i := 0; i < 10; i++ {
|
|
||||||
context.Req.AddCookie(&http.Cookie{Name: fmt.Sprintf("%s%d", cookieNameBase, i), Value: fmt.Sprintf("%s%d", cookieValueBase, i)})
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := 0; i < 10; i++ {
|
|
||||||
expectedCookieVal := fmt.Sprintf("%s%d", cookieValueBase, i)
|
|
||||||
actualCookieVal := context.Cookie(fmt.Sprintf("%s%d", cookieNameBase, i))
|
|
||||||
if actualCookieVal != expectedCookieVal {
|
|
||||||
t.Fatalf("Expected cookie value %s, found %s", expectedCookieVal, actualCookieVal)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHeader(t *testing.T) {
|
|
||||||
context := getContextOrFail(t)
|
|
||||||
|
|
||||||
headerKey, headerVal := "Header1", "HeaderVal1"
|
|
||||||
context.Req.Header.Add(headerKey, headerVal)
|
|
||||||
|
|
||||||
actualHeaderVal := context.Header(headerKey)
|
|
||||||
if actualHeaderVal != headerVal {
|
|
||||||
t.Errorf("Expected header %s, found %s", headerVal, actualHeaderVal)
|
|
||||||
}
|
|
||||||
|
|
||||||
missingHeaderVal := context.Header("not-existing")
|
|
||||||
if missingHeaderVal != "" {
|
|
||||||
t.Errorf("Expected empty header value, found %s", missingHeaderVal)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIP(t *testing.T) {
|
|
||||||
context := getContextOrFail(t)
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
inputRemoteAddr string
|
|
||||||
expectedIP string
|
|
||||||
}{
|
|
||||||
// Test 0 - ipv4 with port
|
|
||||||
{"1.1.1.1:1111", "1.1.1.1"},
|
|
||||||
// Test 1 - ipv4 without port
|
|
||||||
{"1.1.1.1", "1.1.1.1"},
|
|
||||||
// Test 2 - ipv6 with port
|
|
||||||
{"[::1]:11", "::1"},
|
|
||||||
// Test 3 - ipv6 without port and brackets
|
|
||||||
{"[2001:db8:a0b:12f0::1]", "[2001:db8:a0b:12f0::1]"},
|
|
||||||
// Test 4 - ipv6 with zone and port
|
|
||||||
{`[fe80:1::3%eth0]:44`, `fe80:1::3%eth0`},
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, test := range tests {
|
|
||||||
testPrefix := getTestPrefix(i)
|
|
||||||
|
|
||||||
context.Req.RemoteAddr = test.inputRemoteAddr
|
|
||||||
actualIP := context.IP()
|
|
||||||
|
|
||||||
if actualIP != test.expectedIP {
|
|
||||||
t.Errorf(testPrefix+"Expected IP %s, found %s", test.expectedIP, actualIP)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestURL(t *testing.T) {
|
|
||||||
context := getContextOrFail(t)
|
|
||||||
|
|
||||||
inputURL := "http://localhost"
|
|
||||||
context.Req.RequestURI = inputURL
|
|
||||||
|
|
||||||
if inputURL != context.URI() {
|
|
||||||
t.Errorf("Expected url %s, found %s", inputURL, context.URI())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHost(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
input string
|
|
||||||
expectedHost string
|
|
||||||
shouldErr bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
input: "localhost:123",
|
|
||||||
expectedHost: "localhost",
|
|
||||||
shouldErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
input: "localhost",
|
|
||||||
expectedHost: "localhost",
|
|
||||||
shouldErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
input: "[::]",
|
|
||||||
expectedHost: "",
|
|
||||||
shouldErr: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, test := range tests {
|
|
||||||
testHostOrPort(t, true, test.input, test.expectedHost, test.shouldErr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPort(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
input string
|
|
||||||
expectedPort string
|
|
||||||
shouldErr bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
input: "localhost:123",
|
|
||||||
expectedPort: "123",
|
|
||||||
shouldErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
input: "localhost",
|
|
||||||
expectedPort: "80", // assuming 80 is the default port
|
|
||||||
shouldErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
input: ":8080",
|
|
||||||
expectedPort: "8080",
|
|
||||||
shouldErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
input: "[::]",
|
|
||||||
expectedPort: "",
|
|
||||||
shouldErr: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, test := range tests {
|
|
||||||
testHostOrPort(t, false, test.input, test.expectedPort, test.shouldErr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func testHostOrPort(t *testing.T, isTestingHost bool, input, expectedResult string, shouldErr bool) {
|
|
||||||
context := getContextOrFail(t)
|
|
||||||
|
|
||||||
context.Req.Host = input
|
|
||||||
var actualResult, testedObject string
|
|
||||||
var err error
|
|
||||||
|
|
||||||
if isTestingHost {
|
|
||||||
actualResult, err = context.Host()
|
|
||||||
testedObject = "host"
|
|
||||||
} else {
|
|
||||||
actualResult, err = context.Port()
|
|
||||||
testedObject = "port"
|
|
||||||
}
|
|
||||||
|
|
||||||
if shouldErr && err == nil {
|
|
||||||
t.Errorf("Expected error, found nil!")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if !shouldErr && err != nil {
|
|
||||||
t.Errorf("Expected no error, found %s", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if actualResult != expectedResult {
|
|
||||||
t.Errorf("Expected %s %s, found %s", testedObject, expectedResult, actualResult)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMethod(t *testing.T) {
|
|
||||||
context := getContextOrFail(t)
|
|
||||||
|
|
||||||
method := "POST"
|
|
||||||
context.Req.Method = method
|
|
||||||
|
|
||||||
if method != context.Method() {
|
|
||||||
t.Errorf("Expected method %s, found %s", method, context.Method())
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPathMatches(t *testing.T) {
|
|
||||||
context := getContextOrFail(t)
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
urlStr string
|
|
||||||
pattern string
|
|
||||||
shouldMatch bool
|
|
||||||
}{
|
|
||||||
// Test 0
|
|
||||||
{
|
|
||||||
urlStr: "http://localhost/",
|
|
||||||
pattern: "",
|
|
||||||
shouldMatch: true,
|
|
||||||
},
|
|
||||||
// Test 1
|
|
||||||
{
|
|
||||||
urlStr: "http://localhost",
|
|
||||||
pattern: "",
|
|
||||||
shouldMatch: true,
|
|
||||||
},
|
|
||||||
// Test 1
|
|
||||||
{
|
|
||||||
urlStr: "http://localhost/",
|
|
||||||
pattern: "/",
|
|
||||||
shouldMatch: true,
|
|
||||||
},
|
|
||||||
// Test 3
|
|
||||||
{
|
|
||||||
urlStr: "http://localhost/?param=val",
|
|
||||||
pattern: "/",
|
|
||||||
shouldMatch: true,
|
|
||||||
},
|
|
||||||
// Test 4
|
|
||||||
{
|
|
||||||
urlStr: "http://localhost/dir1/dir2",
|
|
||||||
pattern: "/dir2",
|
|
||||||
shouldMatch: false,
|
|
||||||
},
|
|
||||||
// Test 5
|
|
||||||
{
|
|
||||||
urlStr: "http://localhost/dir1/dir2",
|
|
||||||
pattern: "/dir1",
|
|
||||||
shouldMatch: true,
|
|
||||||
},
|
|
||||||
// Test 6
|
|
||||||
{
|
|
||||||
urlStr: "http://localhost:444/dir1/dir2",
|
|
||||||
pattern: "/dir1",
|
|
||||||
shouldMatch: true,
|
|
||||||
},
|
|
||||||
// Test 7
|
|
||||||
{
|
|
||||||
urlStr: "http://localhost/dir1/dir2",
|
|
||||||
pattern: "*/dir2",
|
|
||||||
shouldMatch: false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, test := range tests {
|
|
||||||
testPrefix := getTestPrefix(i)
|
|
||||||
var err error
|
|
||||||
context.Req.URL, err = url.Parse(test.urlStr)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to prepare test URL from string %s! Error was: %s", test.urlStr, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
matches := context.PathMatches(test.pattern)
|
|
||||||
if matches != test.shouldMatch {
|
|
||||||
t.Errorf(testPrefix+"Expected and actual result differ: expected to match [%t], actual matches [%t]", test.shouldMatch, matches)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTruncate(t *testing.T) {
|
|
||||||
context := getContextOrFail(t)
|
|
||||||
tests := []struct {
|
|
||||||
inputString string
|
|
||||||
inputLength int
|
|
||||||
expected string
|
|
||||||
}{
|
|
||||||
// Test 0 - small length
|
|
||||||
{
|
|
||||||
inputString: "string",
|
|
||||||
inputLength: 1,
|
|
||||||
expected: "s",
|
|
||||||
},
|
|
||||||
// Test 1 - exact length
|
|
||||||
{
|
|
||||||
inputString: "string",
|
|
||||||
inputLength: 6,
|
|
||||||
expected: "string",
|
|
||||||
},
|
|
||||||
// Test 2 - bigger length
|
|
||||||
{
|
|
||||||
inputString: "string",
|
|
||||||
inputLength: 10,
|
|
||||||
expected: "string",
|
|
||||||
},
|
|
||||||
// Test 3 - zero length
|
|
||||||
{
|
|
||||||
inputString: "string",
|
|
||||||
inputLength: 0,
|
|
||||||
expected: "",
|
|
||||||
},
|
|
||||||
// Test 4 - negative, smaller length
|
|
||||||
{
|
|
||||||
inputString: "string",
|
|
||||||
inputLength: -5,
|
|
||||||
expected: "tring",
|
|
||||||
},
|
|
||||||
// Test 5 - negative, exact length
|
|
||||||
{
|
|
||||||
inputString: "string",
|
|
||||||
inputLength: -6,
|
|
||||||
expected: "string",
|
|
||||||
},
|
|
||||||
// Test 6 - negative, bigger length
|
|
||||||
{
|
|
||||||
inputString: "string",
|
|
||||||
inputLength: -7,
|
|
||||||
expected: "string",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, test := range tests {
|
|
||||||
actual := context.Truncate(test.inputString, test.inputLength)
|
|
||||||
if actual != test.expected {
|
|
||||||
t.Errorf(getTestPrefix(i)+"Expected '%s', found '%s'. Input was Truncate(%q, %d)", test.expected, actual, test.inputString, test.inputLength)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStripHTML(t *testing.T) {
|
|
||||||
context := getContextOrFail(t)
|
|
||||||
tests := []struct {
|
|
||||||
input string
|
|
||||||
expected string
|
|
||||||
}{
|
|
||||||
// Test 0 - no tags
|
|
||||||
{
|
|
||||||
input: `h1`,
|
|
||||||
expected: `h1`,
|
|
||||||
},
|
|
||||||
// Test 1 - happy path
|
|
||||||
{
|
|
||||||
input: `<h1>h1</h1>`,
|
|
||||||
expected: `h1`,
|
|
||||||
},
|
|
||||||
// Test 2 - tag in quotes
|
|
||||||
{
|
|
||||||
input: `<h1">">h1</h1>`,
|
|
||||||
expected: `h1`,
|
|
||||||
},
|
|
||||||
// Test 3 - multiple tags
|
|
||||||
{
|
|
||||||
input: `<h1><b>h1</b></h1>`,
|
|
||||||
expected: `h1`,
|
|
||||||
},
|
|
||||||
// Test 4 - tags not closed
|
|
||||||
{
|
|
||||||
input: `<h1`,
|
|
||||||
expected: `<h1`,
|
|
||||||
},
|
|
||||||
// Test 5 - false start
|
|
||||||
{
|
|
||||||
input: `<h1<b>hi`,
|
|
||||||
expected: `<h1hi`,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, test := range tests {
|
|
||||||
actual := context.StripHTML(test.input)
|
|
||||||
if actual != test.expected {
|
|
||||||
t.Errorf(getTestPrefix(i)+"Expected %s, found %s. Input was StripHTML(%s)", test.expected, actual, test.input)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStripExt(t *testing.T) {
|
|
||||||
context := getContextOrFail(t)
|
|
||||||
tests := []struct {
|
|
||||||
input string
|
|
||||||
expected string
|
|
||||||
}{
|
|
||||||
// Test 0 - empty input
|
|
||||||
{
|
|
||||||
input: "",
|
|
||||||
expected: "",
|
|
||||||
},
|
|
||||||
// Test 1 - relative file with ext
|
|
||||||
{
|
|
||||||
input: "file.ext",
|
|
||||||
expected: "file",
|
|
||||||
},
|
|
||||||
// Test 2 - relative file without ext
|
|
||||||
{
|
|
||||||
input: "file",
|
|
||||||
expected: "file",
|
|
||||||
},
|
|
||||||
// Test 3 - absolute file without ext
|
|
||||||
{
|
|
||||||
input: "/file",
|
|
||||||
expected: "/file",
|
|
||||||
},
|
|
||||||
// Test 4 - absolute file with ext
|
|
||||||
{
|
|
||||||
input: "/file.ext",
|
|
||||||
expected: "/file",
|
|
||||||
},
|
|
||||||
// Test 5 - with ext but ends with /
|
|
||||||
{
|
|
||||||
input: "/dir.ext/",
|
|
||||||
expected: "/dir.ext/",
|
|
||||||
},
|
|
||||||
// Test 6 - file with ext under dir with ext
|
|
||||||
{
|
|
||||||
input: "/dir.ext/file.ext",
|
|
||||||
expected: "/dir.ext/file",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, test := range tests {
|
|
||||||
actual := context.StripExt(test.input)
|
|
||||||
if actual != test.expected {
|
|
||||||
t.Errorf(getTestPrefix(i)+"Expected %s, found %s. Input was StripExt(%q)", test.expected, actual, test.input)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func initTestContext() (Context, error) {
|
|
||||||
body := bytes.NewBufferString("request body")
|
|
||||||
request, err := http.NewRequest("GET", "https://localhost", body)
|
|
||||||
if err != nil {
|
|
||||||
return Context{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return Context{Root: http.Dir(os.TempDir()), Req: request}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getContextOrFail(t *testing.T) Context {
|
|
||||||
context, err := initTestContext()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to prepare test context")
|
|
||||||
}
|
|
||||||
return context
|
|
||||||
}
|
|
||||||
|
|
||||||
func getTestPrefix(testN int) string {
|
|
||||||
return fmt.Sprintf("Test [%d]: ", testN)
|
|
||||||
}
|
|
|
@ -8,6 +8,8 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/net/context"
|
||||||
|
|
||||||
"github.com/miekg/coredns/middleware"
|
"github.com/miekg/coredns/middleware"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
)
|
)
|
||||||
|
@ -21,10 +23,10 @@ type ErrorHandler struct {
|
||||||
Debug bool // if true, errors are written out to client rather than to a log
|
Debug bool // if true, errors are written out to client rather than to a log
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h ErrorHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
func (h ErrorHandler) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||||
defer h.recovery(w, r)
|
defer h.recovery(w, r)
|
||||||
|
|
||||||
rcode, err := h.Next.ServeDNS(w, r)
|
rcode, err := h.Next.ServeDNS(ctx, w, r)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errMsg := fmt.Sprintf("%s [ERROR %d %s %s] %v", time.Now().Format(timeFormat), rcode, r.Question[0].Name, dns.Type(r.Question[0].Qclass), err)
|
errMsg := fmt.Sprintf("%s [ERROR %d %s %s] %v", time.Now().Format(timeFormat), rcode, r.Question[0].Name, dns.Type(r.Question[0].Qclass), err)
|
||||||
|
|
|
@ -1,21 +1,6 @@
|
||||||
package errors
|
package errors
|
||||||
|
|
||||||
import (
|
/*
|
||||||
"bytes"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"log"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/miekg/coredns/middleware"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestErrors(t *testing.T) {
|
func TestErrors(t *testing.T) {
|
||||||
// create a temporary page
|
// create a temporary page
|
||||||
path := filepath.Join(os.TempDir(), "errors_test.html")
|
path := filepath.Join(os.TempDir(), "errors_test.html")
|
||||||
|
@ -166,3 +151,4 @@ func genErrorHandler(status int, err error, body string) middleware.Handler {
|
||||||
return status, err
|
return status, err
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
*/
|
||||||
|
|
|
@ -8,6 +8,8 @@ package file
|
||||||
import (
|
import (
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"golang.org/x/net/context"
|
||||||
|
|
||||||
"github.com/miekg/coredns/middleware"
|
"github.com/miekg/coredns/middleware"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
)
|
)
|
||||||
|
@ -26,29 +28,29 @@ type (
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
func (f File) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
func (f File) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||||
context := middleware.Context{W: w, Req: r}
|
state := middleware.State{W: w, Req: r}
|
||||||
qname := context.Name()
|
qname := state.Name()
|
||||||
zone := middleware.Zones(f.Zones.Names).Matches(qname)
|
zone := middleware.Zones(f.Zones.Names).Matches(qname)
|
||||||
if zone == "" {
|
if zone == "" {
|
||||||
return f.Next.ServeDNS(w, r)
|
return f.Next.ServeDNS(ctx, w, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
names, nodata := f.Zones.Z[zone].lookup(qname, context.QType())
|
names, nodata := f.Zones.Z[zone].lookup(qname, state.QType())
|
||||||
var answer *dns.Msg
|
var answer *dns.Msg
|
||||||
switch {
|
switch {
|
||||||
case nodata:
|
case nodata:
|
||||||
answer = context.AnswerMessage()
|
answer = state.AnswerMessage()
|
||||||
answer.Ns = names
|
answer.Ns = names
|
||||||
case len(names) == 0:
|
case len(names) == 0:
|
||||||
answer = context.AnswerMessage()
|
answer = state.AnswerMessage()
|
||||||
answer.Ns = names
|
answer.Ns = names
|
||||||
answer.Rcode = dns.RcodeNameError
|
answer.Rcode = dns.RcodeNameError
|
||||||
case len(names) > 0:
|
case len(names) > 0:
|
||||||
answer = context.AnswerMessage()
|
answer = state.AnswerMessage()
|
||||||
answer.Answer = names
|
answer.Answer = names
|
||||||
default:
|
default:
|
||||||
answer = context.ErrorMessage(dns.RcodeServerFailure)
|
answer = state.ErrorMessage(dns.RcodeServerFailure)
|
||||||
}
|
}
|
||||||
// Check return size, etc. TODO(miek)
|
// Check return size, etc. TODO(miek)
|
||||||
w.WriteMsg(answer)
|
w.WriteMsg(answer)
|
||||||
|
|
|
@ -1,15 +1,6 @@
|
||||||
package file
|
package file
|
||||||
|
|
||||||
import (
|
/*
|
||||||
"errors"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
var testDir = filepath.Join(os.TempDir(), "caddy_testdir")
|
var testDir = filepath.Join(os.TempDir(), "caddy_testdir")
|
||||||
var ErrCustom = errors.New("Custom Error")
|
var ErrCustom = errors.New("Custom Error")
|
||||||
|
|
||||||
|
@ -323,3 +314,4 @@ func TestServeHTTPFailingStat(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
*/
|
||||||
|
|
|
@ -4,6 +4,8 @@ package log
|
||||||
import (
|
import (
|
||||||
"log"
|
"log"
|
||||||
|
|
||||||
|
"golang.org/x/net/context"
|
||||||
|
|
||||||
"github.com/miekg/coredns/middleware"
|
"github.com/miekg/coredns/middleware"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
)
|
)
|
||||||
|
@ -15,7 +17,7 @@ type Logger struct {
|
||||||
ErrorFunc func(dns.ResponseWriter, *dns.Msg, int) // failover error handler
|
ErrorFunc func(dns.ResponseWriter, *dns.Msg, int) // failover error handler
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l Logger) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
func (l Logger) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||||
for _, rule := range l.Rules {
|
for _, rule := range l.Rules {
|
||||||
/*
|
/*
|
||||||
if middleware.Path(r.URL.Path).Matches(rule.PathScope) {
|
if middleware.Path(r.URL.Path).Matches(rule.PathScope) {
|
||||||
|
@ -40,7 +42,7 @@ func (l Logger) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||||
*/
|
*/
|
||||||
rule = rule
|
rule = rule
|
||||||
}
|
}
|
||||||
return l.Next.ServeDNS(w, r)
|
return l.Next.ServeDNS(ctx, w, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Rule configures the logging middleware.
|
// Rule configures the logging middleware.
|
||||||
|
|
|
@ -1,17 +1,9 @@
|
||||||
package log
|
package log
|
||||||
|
|
||||||
import (
|
/*
|
||||||
"bytes"
|
|
||||||
"log"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
type erroringMiddleware struct{}
|
type erroringMiddleware struct{}
|
||||||
|
|
||||||
func (erroringMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
|
func (erroringMiddleware) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||||
return http.StatusNotFound, nil
|
return http.StatusNotFound, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -46,3 +38,4 @@ func TestLoggedStatus(t *testing.T) {
|
||||||
t.Error("Expected 404 to be logged. Logged string -", logged)
|
t.Error("Expected 404 to be logged. Logged string -", logged)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
*/
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
|
"golang.org/x/net/context"
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
|
@ -32,18 +33,18 @@ type (
|
||||||
// Otherwise, return values should be propagated down the middleware
|
// Otherwise, return values should be propagated down the middleware
|
||||||
// chain by returning them unchanged.
|
// chain by returning them unchanged.
|
||||||
Handler interface {
|
Handler interface {
|
||||||
ServeDNS(dns.ResponseWriter, *dns.Msg) (int, error)
|
ServeDNS(context.Context, dns.ResponseWriter, *dns.Msg) (int, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// HandlerFunc is a convenience type like dns.HandlerFunc, except
|
// HandlerFunc is a convenience type like dns.HandlerFunc, except
|
||||||
// ServeDNS returns an rcode and an error. See Handler
|
// ServeDNS returns an rcode and an error. See Handler
|
||||||
// documentation for more information.
|
// documentation for more information.
|
||||||
HandlerFunc func(dns.ResponseWriter, *dns.Msg) (int, error)
|
HandlerFunc func(context.Context, dns.ResponseWriter, *dns.Msg) (int, error)
|
||||||
)
|
)
|
||||||
|
|
||||||
// ServeDNS implements the Handler interface.
|
// ServeDNS implements the Handler interface.
|
||||||
func (f HandlerFunc) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
func (f HandlerFunc) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||||
return f(w, r)
|
return f(ctx, w, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
// IndexFile looks for a file in /root/fpath/indexFile for each string
|
// IndexFile looks for a file in /root/fpath/indexFile for each string
|
||||||
|
|
|
@ -1,108 +1 @@
|
||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestIndexfile(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
rootDir http.FileSystem
|
|
||||||
fpath string
|
|
||||||
indexFiles []string
|
|
||||||
shouldErr bool
|
|
||||||
expectedFilePath string //retun value
|
|
||||||
expectedBoolValue bool //return value
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
http.Dir("./templates/testdata"),
|
|
||||||
"/images/",
|
|
||||||
[]string{"img.htm"},
|
|
||||||
false,
|
|
||||||
"/images/img.htm",
|
|
||||||
true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for i, test := range tests {
|
|
||||||
actualFilePath, actualBoolValue := IndexFile(test.rootDir, test.fpath, test.indexFiles)
|
|
||||||
if actualBoolValue == true && test.shouldErr {
|
|
||||||
t.Errorf("Test %d didn't error, but it should have", i)
|
|
||||||
} else if actualBoolValue != true && !test.shouldErr {
|
|
||||||
t.Errorf("Test %d errored, but it shouldn't have; got %s", i, "Please Add a / at the end of fpath or the indexFiles doesnt exist")
|
|
||||||
}
|
|
||||||
if actualFilePath != test.expectedFilePath {
|
|
||||||
t.Fatalf("Test %d expected returned filepath to be %s, but got %s ",
|
|
||||||
i, test.expectedFilePath, actualFilePath)
|
|
||||||
|
|
||||||
}
|
|
||||||
if actualBoolValue != test.expectedBoolValue {
|
|
||||||
t.Fatalf("Test %d expected returned bool value to be %v, but got %v ",
|
|
||||||
i, test.expectedBoolValue, actualBoolValue)
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSetLastModified(t *testing.T) {
|
|
||||||
nowTime := time.Now()
|
|
||||||
|
|
||||||
// ovewrite the function to return reliable time
|
|
||||||
originalGetCurrentTimeFunc := currentTime
|
|
||||||
currentTime = func() time.Time {
|
|
||||||
return nowTime
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
currentTime = originalGetCurrentTimeFunc
|
|
||||||
}()
|
|
||||||
|
|
||||||
pastTime := nowTime.Truncate(1 * time.Hour)
|
|
||||||
futureTime := nowTime.Add(1 * time.Hour)
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
inputModTime time.Time
|
|
||||||
expectedIsHeaderSet bool
|
|
||||||
expectedLastModified string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
inputModTime: pastTime,
|
|
||||||
expectedIsHeaderSet: true,
|
|
||||||
expectedLastModified: pastTime.UTC().Format(http.TimeFormat),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
inputModTime: nowTime,
|
|
||||||
expectedIsHeaderSet: true,
|
|
||||||
expectedLastModified: nowTime.UTC().Format(http.TimeFormat),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
inputModTime: futureTime,
|
|
||||||
expectedIsHeaderSet: true,
|
|
||||||
expectedLastModified: nowTime.UTC().Format(http.TimeFormat),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
inputModTime: time.Time{},
|
|
||||||
expectedIsHeaderSet: false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, test := range tests {
|
|
||||||
responseRecorder := httptest.NewRecorder()
|
|
||||||
errorPrefix := fmt.Sprintf("Test [%d]: ", i)
|
|
||||||
SetLastModifiedHeader(responseRecorder, test.inputModTime)
|
|
||||||
actualLastModifiedHeader := responseRecorder.Header().Get("Last-Modified")
|
|
||||||
|
|
||||||
if test.expectedIsHeaderSet && actualLastModifiedHeader == "" {
|
|
||||||
t.Fatalf(errorPrefix + "Expected to find Last-Modified header, but found nothing")
|
|
||||||
}
|
|
||||||
|
|
||||||
if !test.expectedIsHeaderSet && actualLastModifiedHeader != "" {
|
|
||||||
t.Fatalf(errorPrefix+"Did not expect to find Last-Modified header, but found one [%s].", actualLastModifiedHeader)
|
|
||||||
}
|
|
||||||
|
|
||||||
if test.expectedLastModified != actualLastModifiedHeader {
|
|
||||||
t.Errorf(errorPrefix+"Expected Last-Modified content [%s], found [%s}", test.expectedLastModified, actualLastModifiedHeader)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
@ -4,15 +4,17 @@ import (
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/net/context"
|
||||||
|
|
||||||
"github.com/miekg/coredns/middleware"
|
"github.com/miekg/coredns/middleware"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (m *Metrics) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
func (m *Metrics) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||||
context := middleware.Context{W: w, Req: r}
|
state := middleware.State{W: w, Req: r}
|
||||||
|
|
||||||
qname := context.Name()
|
qname := state.Name()
|
||||||
qtype := context.Type()
|
qtype := state.Type()
|
||||||
zone := middleware.Zones(m.ZoneNames).Matches(qname)
|
zone := middleware.Zones(m.ZoneNames).Matches(qname)
|
||||||
if zone == "" {
|
if zone == "" {
|
||||||
zone = "."
|
zone = "."
|
||||||
|
@ -20,7 +22,7 @@ func (m *Metrics) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||||
|
|
||||||
// Record response to get status code and size of the reply.
|
// Record response to get status code and size of the reply.
|
||||||
rw := middleware.NewResponseRecorder(w)
|
rw := middleware.NewResponseRecorder(w)
|
||||||
status, err := m.Next.ServeDNS(rw, r)
|
status, err := m.Next.ServeDNS(ctx, rw, r)
|
||||||
|
|
||||||
requestCount.WithLabelValues(zone, qtype).Inc()
|
requestCount.WithLabelValues(zone, qtype).Inc()
|
||||||
requestDuration.WithLabelValues(zone).Observe(float64(time.Since(rw.Start()) / time.Second))
|
requestDuration.WithLabelValues(zone).Observe(float64(time.Since(rw.Start()) / time.Second))
|
||||||
|
|
|
@ -7,6 +7,8 @@ import (
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/net/context"
|
||||||
|
|
||||||
"github.com/miekg/coredns/middleware"
|
"github.com/miekg/coredns/middleware"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
)
|
)
|
||||||
|
@ -67,7 +69,7 @@ func (uh *UpstreamHost) Down() bool {
|
||||||
var tryDuration = 60 * time.Second
|
var tryDuration = 60 * time.Second
|
||||||
|
|
||||||
// ServeDNS satisfies the middleware.Handler interface.
|
// ServeDNS satisfies the middleware.Handler interface.
|
||||||
func (p Proxy) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
func (p Proxy) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||||
for _, upstream := range p.Upstreams {
|
for _, upstream := range p.Upstreams {
|
||||||
// allowed bla bla bla TODO(miek): fix full proxy spec from caddy
|
// allowed bla bla bla TODO(miek): fix full proxy spec from caddy
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
|
@ -100,7 +102,7 @@ func (p Proxy) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||||
}
|
}
|
||||||
return dns.RcodeServerFailure, errUnreachable
|
return dns.RcodeServerFailure, errUnreachable
|
||||||
}
|
}
|
||||||
return p.Next.ServeDNS(w, r)
|
return p.Next.ServeDNS(ctx, w, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
func Clients() Client {
|
func Clients() Client {
|
||||||
|
|
|
@ -1,26 +1,6 @@
|
||||||
package proxy
|
package proxy
|
||||||
|
|
||||||
import (
|
/*
|
||||||
"bufio"
|
|
||||||
"bytes"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"io/ioutil"
|
|
||||||
"log"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"net/url"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"runtime"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"golang.org/x/net/websocket"
|
|
||||||
)
|
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
tryDuration = 50 * time.Millisecond // prevent tests from hanging
|
tryDuration = 50 * time.Millisecond // prevent tests from hanging
|
||||||
}
|
}
|
||||||
|
@ -315,3 +295,4 @@ func (c *fakeConn) SetWriteDeadline(t time.Time) error { return nil }
|
||||||
func (c *fakeConn) Close() error { return nil }
|
func (c *fakeConn) Close() error { return nil }
|
||||||
func (c *fakeConn) Read(b []byte) (int, error) { return c.readBuf.Read(b) }
|
func (c *fakeConn) Read(b []byte) (int, error) { return c.readBuf.Read(b) }
|
||||||
func (c *fakeConn) Write(b []byte) (int, error) { return c.writeBuf.Write(b) }
|
func (c *fakeConn) Write(b []byte) (int, error) { return c.writeBuf.Write(b) }
|
||||||
|
*/
|
||||||
|
|
|
@ -12,15 +12,15 @@ type ReverseProxy struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p ReverseProxy) ServeDNS(w dns.ResponseWriter, r *dns.Msg, extra []dns.RR) error {
|
func (p ReverseProxy) ServeDNS(w dns.ResponseWriter, r *dns.Msg, extra []dns.RR) error {
|
||||||
// TODO(miek): use extra!
|
// TODO(miek): use extra to EDNS0.
|
||||||
var (
|
var (
|
||||||
reply *dns.Msg
|
reply *dns.Msg
|
||||||
err error
|
err error
|
||||||
)
|
)
|
||||||
context := middleware.Context{W: w, Req: r}
|
state := middleware.State{W: w, Req: r}
|
||||||
|
|
||||||
// tls+tcp ?
|
// tls+tcp ?
|
||||||
if context.Proto() == "tcp" {
|
if state.Proto() == "tcp" {
|
||||||
reply, err = middleware.Exchange(p.Client.TCP, r, p.Host)
|
reply, err = middleware.Exchange(p.Client.TCP, r, p.Host)
|
||||||
} else {
|
} else {
|
||||||
reply, err = middleware.Exchange(p.Client.UDP, r, p.Host)
|
reply, err = middleware.Exchange(p.Client.UDP, r, p.Host)
|
||||||
|
|
|
@ -1,11 +1,6 @@
|
||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
/*
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestNewResponseRecorder(t *testing.T) {
|
func TestNewResponseRecorder(t *testing.T) {
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
recordRequest := NewResponseRecorder(w)
|
recordRequest := NewResponseRecorder(w)
|
||||||
|
@ -30,3 +25,4 @@ func TestWrite(t *testing.T) {
|
||||||
t.Fatalf("Expected Response Body to be %s , but found %s\n", responseTestString, w.Body.String())
|
t.Fatalf("Expected Response Body to be %s , but found %s\n", responseTestString, w.Body.String())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
*/
|
||||||
|
|
|
@ -20,6 +20,8 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"golang.org/x/net/context"
|
||||||
|
|
||||||
"github.com/miekg/coredns/middleware"
|
"github.com/miekg/coredns/middleware"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
)
|
)
|
||||||
|
@ -28,15 +30,15 @@ type Reflect struct {
|
||||||
Next middleware.Handler
|
Next middleware.Handler
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rl Reflect) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
func (rl Reflect) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||||
context := middleware.Context{Req: r, W: w}
|
state := middleware.State{Req: r, W: w}
|
||||||
|
|
||||||
class := r.Question[0].Qclass
|
class := r.Question[0].Qclass
|
||||||
qname := r.Question[0].Name
|
qname := r.Question[0].Name
|
||||||
i, ok := dns.NextLabel(qname, 0)
|
i, ok := dns.NextLabel(qname, 0)
|
||||||
|
|
||||||
if strings.ToLower(qname[:i]) != who || ok {
|
if strings.ToLower(qname[:i]) != who || ok {
|
||||||
err := context.ErrorMessage(dns.RcodeFormatError)
|
err := state.ErrorMessage(dns.RcodeFormatError)
|
||||||
w.WriteMsg(err)
|
w.WriteMsg(err)
|
||||||
return dns.RcodeFormatError, errors.New(dns.RcodeToString[dns.RcodeFormatError])
|
return dns.RcodeFormatError, errors.New(dns.RcodeToString[dns.RcodeFormatError])
|
||||||
}
|
}
|
||||||
|
@ -46,10 +48,10 @@ func (rl Reflect) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||||
answer.Compress = true
|
answer.Compress = true
|
||||||
answer.Authoritative = true
|
answer.Authoritative = true
|
||||||
|
|
||||||
ip := context.IP()
|
ip := state.IP()
|
||||||
proto := context.Proto()
|
proto := state.Proto()
|
||||||
port, _ := context.Port()
|
port, _ := state.Port()
|
||||||
family := context.Family()
|
family := state.Family()
|
||||||
var rr dns.RR
|
var rr dns.RR
|
||||||
|
|
||||||
switch family {
|
switch family {
|
||||||
|
@ -67,7 +69,7 @@ func (rl Reflect) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||||
t.Hdr = dns.RR_Header{Name: qname, Rrtype: dns.TypeTXT, Class: class, Ttl: 0}
|
t.Hdr = dns.RR_Header{Name: qname, Rrtype: dns.TypeTXT, Class: class, Ttl: 0}
|
||||||
t.Txt = []string{"Port: " + port + " (" + proto + ")"}
|
t.Txt = []string{"Port: " + port + " (" + proto + ")"}
|
||||||
|
|
||||||
switch context.Type() {
|
switch state.Type() {
|
||||||
case "TXT":
|
case "TXT":
|
||||||
answer.Answer = append(answer.Answer, t)
|
answer.Answer = append(answer.Answer, t)
|
||||||
answer.Extra = append(answer.Extra, rr)
|
answer.Extra = append(answer.Extra, rr)
|
||||||
|
|
|
@ -29,19 +29,19 @@ type replacer struct {
|
||||||
// available. emptyValue should be the string that is used
|
// available. emptyValue should be the string that is used
|
||||||
// in place of empty string (can still be empty string).
|
// in place of empty string (can still be empty string).
|
||||||
func NewReplacer(r *dns.Msg, rr *ResponseRecorder, emptyValue string) Replacer {
|
func NewReplacer(r *dns.Msg, rr *ResponseRecorder, emptyValue string) Replacer {
|
||||||
context := Context{W: rr, Req: r}
|
state := State{W: rr, Req: r}
|
||||||
rep := replacer{
|
rep := replacer{
|
||||||
replacements: map[string]string{
|
replacements: map[string]string{
|
||||||
"{type}": context.Type(),
|
"{type}": state.Type(),
|
||||||
"{name}": context.Name(),
|
"{name}": state.Name(),
|
||||||
"{class}": context.Class(),
|
"{class}": state.Class(),
|
||||||
"{proto}": context.Proto(),
|
"{proto}": state.Proto(),
|
||||||
"{when}": func() string {
|
"{when}": func() string {
|
||||||
return time.Now().Format(timeFormat)
|
return time.Now().Format(timeFormat)
|
||||||
}(),
|
}(),
|
||||||
"{remote}": context.IP(),
|
"{remote}": state.IP(),
|
||||||
"{port}": func() string {
|
"{port}": func() string {
|
||||||
p, _ := context.Port()
|
p, _ := state.Port()
|
||||||
return p
|
return p
|
||||||
}(),
|
}(),
|
||||||
},
|
},
|
||||||
|
|
|
@ -1,12 +1,6 @@
|
||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
/*
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestNewReplacer(t *testing.T) {
|
func TestNewReplacer(t *testing.T) {
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
recordRequest := NewResponseRecorder(w)
|
recordRequest := NewResponseRecorder(w)
|
||||||
|
@ -122,3 +116,4 @@ func TestSet(t *testing.T) {
|
||||||
t.Error("Expected variable replacement failed")
|
t.Error("Expected variable replacement failed")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
*/
|
||||||
|
|
|
@ -1,11 +1,6 @@
|
||||||
package rewrite
|
package rewrite
|
||||||
|
|
||||||
import (
|
/*
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestConditions(t *testing.T) {
|
func TestConditions(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
condition string
|
condition string
|
||||||
|
@ -104,3 +99,4 @@ func TestConditions(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
*/
|
||||||
|
|
|
@ -5,6 +5,7 @@ package rewrite
|
||||||
import (
|
import (
|
||||||
"github.com/miekg/coredns/middleware"
|
"github.com/miekg/coredns/middleware"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
|
"golang.org/x/net/context"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Result is the result of a rewrite
|
// Result is the result of a rewrite
|
||||||
|
@ -27,12 +28,12 @@ type Rewrite struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// ServeHTTP implements the middleware.Handler interface.
|
// ServeHTTP implements the middleware.Handler interface.
|
||||||
func (rw Rewrite) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
func (rw Rewrite) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||||
wr := NewResponseReverter(w, r)
|
wr := NewResponseReverter(w, r)
|
||||||
for _, rule := range rw.Rules {
|
for _, rule := range rw.Rules {
|
||||||
switch result := rule.Rewrite(r); result {
|
switch result := rule.Rewrite(r); result {
|
||||||
case RewriteDone:
|
case RewriteDone:
|
||||||
return rw.Next.ServeDNS(wr, r)
|
return rw.Next.ServeDNS(ctx, wr, r)
|
||||||
case RewriteIgnored:
|
case RewriteIgnored:
|
||||||
break
|
break
|
||||||
case RewriteStatus:
|
case RewriteStatus:
|
||||||
|
@ -42,7 +43,7 @@ func (rw Rewrite) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||||
// }
|
// }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return rw.Next.ServeDNS(w, r)
|
return rw.Next.ServeDNS(ctx, w, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Rule describes an internal location rewrite rule.
|
// Rule describes an internal location rewrite rule.
|
||||||
|
|
|
@ -1,15 +1,6 @@
|
||||||
package rewrite
|
package rewrite
|
||||||
|
|
||||||
import (
|
/*
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/miekg/coredns/middleware"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestRewrite(t *testing.T) {
|
func TestRewrite(t *testing.T) {
|
||||||
rw := Rewrite{
|
rw := Rewrite{
|
||||||
Next: middleware.HandlerFunc(urlPrinter),
|
Next: middleware.HandlerFunc(urlPrinter),
|
||||||
|
@ -157,3 +148,4 @@ func urlPrinter(w http.ResponseWriter, r *http.Request) (int, error) {
|
||||||
fmt.Fprintf(w, r.URL.String())
|
fmt.Fprintf(w, r.URL.String())
|
||||||
return 0, nil
|
return 0, nil
|
||||||
}
|
}
|
||||||
|
*/
|
||||||
|
|
|
@ -9,45 +9,44 @@ import (
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
)
|
)
|
||||||
|
|
||||||
// This file contains the context and functions available for
|
// This file contains the state nd functions available for use in the templates.
|
||||||
// use in the templates.
|
|
||||||
|
|
||||||
// Context is the context with which Caddy templates are executed.
|
// State contains some connection state and is useful in middleware.
|
||||||
type Context struct {
|
type State struct {
|
||||||
Root http.FileSystem // TODO(miek): needed
|
Root http.FileSystem // TODO(miek): needed?
|
||||||
Req *dns.Msg
|
Req *dns.Msg
|
||||||
W dns.ResponseWriter
|
W dns.ResponseWriter
|
||||||
}
|
}
|
||||||
|
|
||||||
// Now returns the current timestamp in the specified format.
|
// Now returns the current timestamp in the specified format.
|
||||||
func (c Context) Now(format string) string {
|
func (s State) Now(format string) string {
|
||||||
return time.Now().Format(format)
|
return time.Now().Format(format)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NowDate returns the current date/time that can be used
|
// NowDate returns the current date/time that can be used
|
||||||
// in other time functions.
|
// in other time functions.
|
||||||
func (c Context) NowDate() time.Time {
|
func (s State) NowDate() time.Time {
|
||||||
return time.Now()
|
return time.Now()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Header gets the value of a header.
|
// Header gets the value of a header.
|
||||||
func (c Context) Header() *dns.RR_Header {
|
func (s State) Header() *dns.RR_Header {
|
||||||
// TODO(miek)
|
// TODO(miek)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// IP gets the (remote) IP address of the client making the request.
|
// IP gets the (remote) IP address of the client making the request.
|
||||||
func (c Context) IP() string {
|
func (s State) IP() string {
|
||||||
ip, _, err := net.SplitHostPort(c.W.RemoteAddr().String())
|
ip, _, err := net.SplitHostPort(s.W.RemoteAddr().String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.W.RemoteAddr().String()
|
return s.W.RemoteAddr().String()
|
||||||
}
|
}
|
||||||
return ip
|
return ip
|
||||||
}
|
}
|
||||||
|
|
||||||
// Post gets the (remote) Port of the client making the request.
|
// Post gets the (remote) Port of the client making the request.
|
||||||
func (c Context) Port() (string, error) {
|
func (s State) Port() (string, error) {
|
||||||
_, port, err := net.SplitHostPort(c.W.RemoteAddr().String())
|
_, port, err := net.SplitHostPort(s.W.RemoteAddr().String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "0", err
|
return "0", err
|
||||||
}
|
}
|
||||||
|
@ -56,11 +55,11 @@ func (c Context) Port() (string, error) {
|
||||||
|
|
||||||
// Proto gets the protocol used as the transport. This
|
// Proto gets the protocol used as the transport. This
|
||||||
// will be udp or tcp.
|
// will be udp or tcp.
|
||||||
func (c Context) Proto() string {
|
func (s State) Proto() string {
|
||||||
if _, ok := c.W.RemoteAddr().(*net.UDPAddr); ok {
|
if _, ok := s.W.RemoteAddr().(*net.UDPAddr); ok {
|
||||||
return "udp"
|
return "udp"
|
||||||
}
|
}
|
||||||
if _, ok := c.W.RemoteAddr().(*net.TCPAddr); ok {
|
if _, ok := s.W.RemoteAddr().(*net.TCPAddr); ok {
|
||||||
return "tcp"
|
return "tcp"
|
||||||
}
|
}
|
||||||
return "udp"
|
return "udp"
|
||||||
|
@ -68,9 +67,9 @@ func (c Context) Proto() string {
|
||||||
|
|
||||||
// Family returns the family of the transport.
|
// Family returns the family of the transport.
|
||||||
// 1 for IPv4 and 2 for IPv6.
|
// 1 for IPv4 and 2 for IPv6.
|
||||||
func (c Context) Family() int {
|
func (s State) Family() int {
|
||||||
var a net.IP
|
var a net.IP
|
||||||
ip := c.W.RemoteAddr()
|
ip := s.W.RemoteAddr()
|
||||||
if i, ok := ip.(*net.UDPAddr); ok {
|
if i, ok := ip.(*net.UDPAddr); ok {
|
||||||
a = i.IP
|
a = i.IP
|
||||||
}
|
}
|
||||||
|
@ -85,51 +84,48 @@ func (c Context) Family() int {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Type returns the type of the question as a string.
|
// Type returns the type of the question as a string.
|
||||||
func (c Context) Type() string {
|
func (s State) Type() string {
|
||||||
return dns.Type(c.Req.Question[0].Qtype).String()
|
return dns.Type(s.Req.Question[0].Qtype).String()
|
||||||
}
|
}
|
||||||
|
|
||||||
// QType returns the type of the question as a uint16.
|
// QType returns the type of the question as a uint16.
|
||||||
func (c Context) QType() uint16 {
|
func (s State) QType() uint16 {
|
||||||
return c.Req.Question[0].Qtype
|
return s.Req.Question[0].Qtype
|
||||||
}
|
}
|
||||||
|
|
||||||
// Name returns the name of the question in the request. Note
|
// Name returns the name of the question in the request. Note
|
||||||
// this name will always have a closing dot and will be lower cased.
|
// this name will always have a closing dot and will be lower cased.
|
||||||
func (c Context) Name() string {
|
func (s State) Name() string {
|
||||||
return strings.ToLower(dns.Name(c.Req.Question[0].Name).String())
|
return strings.ToLower(dns.Name(s.Req.Question[0].Name).String())
|
||||||
}
|
}
|
||||||
|
|
||||||
// QName returns the name of the question in the request.
|
// QName returns the name of the question in the request.
|
||||||
func (c Context) QName() string {
|
func (s State) QName() string {
|
||||||
return dns.Name(c.Req.Question[0].Name).String()
|
return dns.Name(s.Req.Question[0].Name).String()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Class returns the class of the question in the request.
|
// Class returns the class of the question in the request.
|
||||||
func (c Context) Class() string {
|
func (s State) Class() string {
|
||||||
return dns.Class(c.Req.Question[0].Qclass).String()
|
return dns.Class(s.Req.Question[0].Qclass).String()
|
||||||
}
|
}
|
||||||
|
|
||||||
// QClass returns the class of the question in the request.
|
// QClass returns the class of the question in the request.
|
||||||
func (c Context) QClass() uint16 {
|
func (s State) QClass() uint16 {
|
||||||
return c.Req.Question[0].Qclass
|
return s.Req.Question[0].Qclass
|
||||||
}
|
}
|
||||||
|
|
||||||
// More convience types for extracting stuff from a message?
|
|
||||||
// Header?
|
|
||||||
|
|
||||||
// ErrorMessage returns an error message suitable for sending
|
// ErrorMessage returns an error message suitable for sending
|
||||||
// back to the client.
|
// back to the client.
|
||||||
func (c Context) ErrorMessage(rcode int) *dns.Msg {
|
func (s State) ErrorMessage(rcode int) *dns.Msg {
|
||||||
m := new(dns.Msg)
|
m := new(dns.Msg)
|
||||||
m.SetRcode(c.Req, rcode)
|
m.SetRcode(s.Req, rcode)
|
||||||
return m
|
return m
|
||||||
}
|
}
|
||||||
|
|
||||||
// AnswerMessage returns an error message suitable for sending
|
// AnswerMessage returns an error message suitable for sending
|
||||||
// back to the client.
|
// back to the client.
|
||||||
func (c Context) AnswerMessage() *dns.Msg {
|
func (s State) AnswerMessage() *dns.Msg {
|
||||||
m := new(dns.Msg)
|
m := new(dns.Msg)
|
||||||
m.SetReply(c.Req)
|
m.SetReply(s.Req)
|
||||||
return m
|
return m
|
||||||
}
|
}
|
235
middleware/state_test.go
Normal file
235
middleware/state_test.go
Normal file
|
@ -0,0 +1,235 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
/*
|
||||||
|
func TestHeader(t *testing.T) {
|
||||||
|
state := getContextOrFail(t)
|
||||||
|
|
||||||
|
headerKey, headerVal := "Header1", "HeaderVal1"
|
||||||
|
state.Req.Header.Add(headerKey, headerVal)
|
||||||
|
|
||||||
|
actualHeaderVal := state.Header(headerKey)
|
||||||
|
if actualHeaderVal != headerVal {
|
||||||
|
t.Errorf("Expected header %s, found %s", headerVal, actualHeaderVal)
|
||||||
|
}
|
||||||
|
|
||||||
|
missingHeaderVal := state.Header("not-existing")
|
||||||
|
if missingHeaderVal != "" {
|
||||||
|
t.Errorf("Expected empty header value, found %s", missingHeaderVal)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIP(t *testing.T) {
|
||||||
|
state := getContextOrFail(t)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
inputRemoteAddr string
|
||||||
|
expectedIP string
|
||||||
|
}{
|
||||||
|
// Test 0 - ipv4 with port
|
||||||
|
{"1.1.1.1:1111", "1.1.1.1"},
|
||||||
|
// Test 1 - ipv4 without port
|
||||||
|
{"1.1.1.1", "1.1.1.1"},
|
||||||
|
// Test 2 - ipv6 with port
|
||||||
|
{"[::1]:11", "::1"},
|
||||||
|
// Test 3 - ipv6 without port and brackets
|
||||||
|
{"[2001:db8:a0b:12f0::1]", "[2001:db8:a0b:12f0::1]"},
|
||||||
|
// Test 4 - ipv6 with zone and port
|
||||||
|
{`[fe80:1::3%eth0]:44`, `fe80:1::3%eth0`},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, test := range tests {
|
||||||
|
testPrefix := getTestPrefix(i)
|
||||||
|
|
||||||
|
state.Req.RemoteAddr = test.inputRemoteAddr
|
||||||
|
actualIP := state.IP()
|
||||||
|
|
||||||
|
if actualIP != test.expectedIP {
|
||||||
|
t.Errorf(testPrefix+"Expected IP %s, found %s", test.expectedIP, actualIP)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestURL(t *testing.T) {
|
||||||
|
state := getContextOrFail(t)
|
||||||
|
|
||||||
|
inputURL := "http://localhost"
|
||||||
|
state.Req.RequestURI = inputURL
|
||||||
|
|
||||||
|
if inputURL != state.URI() {
|
||||||
|
t.Errorf("Expected url %s, found %s", inputURL, state.URI())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHost(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
expectedHost string
|
||||||
|
shouldErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
input: "localhost:123",
|
||||||
|
expectedHost: "localhost",
|
||||||
|
shouldErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: "localhost",
|
||||||
|
expectedHost: "localhost",
|
||||||
|
shouldErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: "[::]",
|
||||||
|
expectedHost: "",
|
||||||
|
shouldErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
testHostOrPort(t, true, test.input, test.expectedHost, test.shouldErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPort(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
expectedPort string
|
||||||
|
shouldErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
input: "localhost:123",
|
||||||
|
expectedPort: "123",
|
||||||
|
shouldErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: "localhost",
|
||||||
|
expectedPort: "80", // assuming 80 is the default port
|
||||||
|
shouldErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: ":8080",
|
||||||
|
expectedPort: "8080",
|
||||||
|
shouldErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: "[::]",
|
||||||
|
expectedPort: "",
|
||||||
|
shouldErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
testHostOrPort(t, false, test.input, test.expectedPort, test.shouldErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testHostOrPort(t *testing.T, isTestingHost bool, input, expectedResult string, shouldErr bool) {
|
||||||
|
state := getContextOrFail(t)
|
||||||
|
|
||||||
|
state.Req.Host = input
|
||||||
|
var actualResult, testedObject string
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if isTestingHost {
|
||||||
|
actualResult, err = state.Host()
|
||||||
|
testedObject = "host"
|
||||||
|
} else {
|
||||||
|
actualResult, err = state.Port()
|
||||||
|
testedObject = "port"
|
||||||
|
}
|
||||||
|
|
||||||
|
if shouldErr && err == nil {
|
||||||
|
t.Errorf("Expected error, found nil!")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !shouldErr && err != nil {
|
||||||
|
t.Errorf("Expected no error, found %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if actualResult != expectedResult {
|
||||||
|
t.Errorf("Expected %s %s, found %s", testedObject, expectedResult, actualResult)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPathMatches(t *testing.T) {
|
||||||
|
state := getContextOrFail(t)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
urlStr string
|
||||||
|
pattern string
|
||||||
|
shouldMatch bool
|
||||||
|
}{
|
||||||
|
// Test 0
|
||||||
|
{
|
||||||
|
urlStr: "http://localhost/",
|
||||||
|
pattern: "",
|
||||||
|
shouldMatch: true,
|
||||||
|
},
|
||||||
|
// Test 1
|
||||||
|
{
|
||||||
|
urlStr: "http://localhost",
|
||||||
|
pattern: "",
|
||||||
|
shouldMatch: true,
|
||||||
|
},
|
||||||
|
// Test 1
|
||||||
|
{
|
||||||
|
urlStr: "http://localhost/",
|
||||||
|
pattern: "/",
|
||||||
|
shouldMatch: true,
|
||||||
|
},
|
||||||
|
// Test 3
|
||||||
|
{
|
||||||
|
urlStr: "http://localhost/?param=val",
|
||||||
|
pattern: "/",
|
||||||
|
shouldMatch: true,
|
||||||
|
},
|
||||||
|
// Test 4
|
||||||
|
{
|
||||||
|
urlStr: "http://localhost/dir1/dir2",
|
||||||
|
pattern: "/dir2",
|
||||||
|
shouldMatch: false,
|
||||||
|
},
|
||||||
|
// Test 5
|
||||||
|
{
|
||||||
|
urlStr: "http://localhost/dir1/dir2",
|
||||||
|
pattern: "/dir1",
|
||||||
|
shouldMatch: true,
|
||||||
|
},
|
||||||
|
// Test 6
|
||||||
|
{
|
||||||
|
urlStr: "http://localhost:444/dir1/dir2",
|
||||||
|
pattern: "/dir1",
|
||||||
|
shouldMatch: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, test := range tests {
|
||||||
|
testPrefix := getTestPrefix(i)
|
||||||
|
var err error
|
||||||
|
state.Req.URL, err = url.Parse(test.urlStr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to prepare test URL from string %s! Error was: %s", test.urlStr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
matches := state.PathMatches(test.pattern)
|
||||||
|
if matches != test.shouldMatch {
|
||||||
|
t.Errorf(testPrefix+"Expected and actual result differ: expected to match [%t], actual matches [%t]", test.shouldMatch, matches)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func initTestContext() (Context, error) {
|
||||||
|
body := bytes.NewBufferString("request body")
|
||||||
|
request, err := http.NewRequest("GET", "https://localhost", body)
|
||||||
|
if err != nil {
|
||||||
|
return Context{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return Context{Root: http.Dir(os.TempDir()), Req: request}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
func getTestPrefix(testN int) string {
|
||||||
|
return fmt.Sprintf("Test [%d]: ", testN)
|
||||||
|
}
|
||||||
|
*/
|
|
@ -15,6 +15,8 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/net/context"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -285,6 +287,7 @@ func (s *Server) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
q := r.Question[0].Name
|
q := r.Question[0].Name
|
||||||
b := make([]byte, len(q))
|
b := make([]byte, len(q))
|
||||||
off, end := 0, false
|
off, end := 0, false
|
||||||
|
ctx := context.Background()
|
||||||
for {
|
for {
|
||||||
l := len(q[off:])
|
l := len(q[off:])
|
||||||
for i := 0; i < l; i++ {
|
for i := 0; i < l; i++ {
|
||||||
|
@ -297,7 +300,7 @@ func (s *Server) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
|
|
||||||
if h, ok := s.zones[string(b[:l])]; ok {
|
if h, ok := s.zones[string(b[:l])]; ok {
|
||||||
if r.Question[0].Qtype != dns.TypeDS {
|
if r.Question[0].Qtype != dns.TypeDS {
|
||||||
rcode, _ := h.stack.ServeDNS(w, r)
|
rcode, _ := h.stack.ServeDNS(ctx, w, r)
|
||||||
if rcode > 0 {
|
if rcode > 0 {
|
||||||
DefaultErrorFunc(w, r, rcode)
|
DefaultErrorFunc(w, r, rcode)
|
||||||
}
|
}
|
||||||
|
@ -311,7 +314,7 @@ func (s *Server) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
}
|
}
|
||||||
// Wildcard match, if we have found nothing try the root zone as a last resort.
|
// Wildcard match, if we have found nothing try the root zone as a last resort.
|
||||||
if h, ok := s.zones["."]; ok {
|
if h, ok := s.zones["."]; ok {
|
||||||
rcode, _ := h.stack.ServeDNS(w, r)
|
rcode, _ := h.stack.ServeDNS(ctx, w, r)
|
||||||
if rcode > 0 {
|
if rcode > 0 {
|
||||||
DefaultErrorFunc(w, r, rcode)
|
DefaultErrorFunc(w, r, rcode)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue