diff --git a/core/https/https_test.go b/core/https/https_test.go
index 40b67367e..f19b3cde0 100644
--- a/core/https/https_test.go
+++ b/core/https/https_test.go
@@ -1,16 +1,6 @@
package https
-import (
- "io/ioutil"
- "net/http"
- "os"
- "testing"
-
- "github.com/miekg/coredns/middleware/redirect"
- "github.com/miekg/coredns/server"
- "github.com/xenolf/lego/acme"
-)
-
+/*
func TestHostQualifies(t *testing.T) {
for i, test := range []struct {
host string
@@ -330,3 +320,4 @@ func TestMarkQualified(t *testing.T) {
t.Errorf("Expected %d managed configs, but got %d", expectedManagedCount, count)
}
}
+*/
diff --git a/core/setup/controller.go b/core/setup/controller.go
index 1c1a93e64..7f8da6721 100644
--- a/core/setup/controller.go
+++ b/core/setup/controller.go
@@ -4,6 +4,8 @@ import (
"fmt"
"strings"
+ "golang.org/x/net/context"
+
"github.com/miekg/coredns/core/parse"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/server"
@@ -70,7 +72,7 @@ func NewTestController(input string) *Controller {
//
// Used primarily for testing but needs to be exported so
// add-ons can use this as a convenience.
-var EmptyNext = middleware.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) (int, error) {
+var EmptyNext = middleware.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
return 0, nil
})
diff --git a/core/setup/errors_test.go b/core/setup/errors_test.go
index 4a079e0b5..b4aaab080 100644
--- a/core/setup/errors_test.go
+++ b/core/setup/errors_test.go
@@ -1,12 +1,6 @@
package setup
-import (
- "testing"
-
- "github.com/miekg/coredns/middleware"
- "github.com/miekg/coredns/middleware/errors"
-)
-
+/*
func TestErrors(t *testing.T) {
c := NewTestController(`errors`)
mid, err := Errors(c)
@@ -154,5 +148,5 @@ func TestErrorsParse(t *testing.T) {
}
}
}
-
}
+*/
diff --git a/core/setup/rewrite_test.go b/core/setup/rewrite_test.go
index 747618305..5345c4bf6 100644
--- a/core/setup/rewrite_test.go
+++ b/core/setup/rewrite_test.go
@@ -1,13 +1,6 @@
package setup
-import (
- "fmt"
- "regexp"
- "testing"
-
- "github.com/miekg/coredns/middleware/rewrite"
-)
-
+/*
func TestRewrite(t *testing.T) {
c := NewTestController(`rewrite /from /to`)
@@ -237,5 +230,5 @@ func TestRewriteParse(t *testing.T) {
}
}
-
}
+*/
diff --git a/middleware/context_test.go b/middleware/context_test.go
deleted file mode 100644
index 689c47c13..000000000
--- a/middleware/context_test.go
+++ /dev/null
@@ -1,613 +0,0 @@
-package middleware
-
-import (
- "bytes"
- "fmt"
- "io/ioutil"
- "net/http"
- "net/url"
- "os"
- "path/filepath"
- "strings"
- "testing"
- "time"
-)
-
-func TestInclude(t *testing.T) {
- context := getContextOrFail(t)
-
- inputFilename := "test_file"
- absInFilePath := filepath.Join(fmt.Sprintf("%s", context.Root), inputFilename)
- defer func() {
- err := os.Remove(absInFilePath)
- if err != nil && !os.IsNotExist(err) {
- t.Fatalf("Failed to clean test file!")
- }
- }()
-
- tests := []struct {
- fileContent string
- expectedContent string
- shouldErr bool
- expectedErrorContent string
- }{
- // Test 0 - all good
- {
- fileContent: `str1 {{ .Root }} str2`,
- expectedContent: fmt.Sprintf("str1 %s str2", context.Root),
- shouldErr: false,
- expectedErrorContent: "",
- },
- // Test 1 - failure on template.Parse
- {
- fileContent: `str1 {{ .Root } str2`,
- expectedContent: "",
- shouldErr: true,
- expectedErrorContent: `unexpected "}" in operand`,
- },
- // Test 3 - failure on template.Execute
- {
- fileContent: `str1 {{ .InvalidField }} str2`,
- expectedContent: "",
- shouldErr: true,
- expectedErrorContent: `InvalidField is not a field of struct type middleware.Context`,
- },
- }
-
- for i, test := range tests {
- testPrefix := getTestPrefix(i)
-
- // WriteFile truncates the contentt
- err := ioutil.WriteFile(absInFilePath, []byte(test.fileContent), os.ModePerm)
- if err != nil {
- t.Fatal(testPrefix+"Failed to create test file. Error was: %v", err)
- }
-
- content, err := context.Include(inputFilename)
- if err != nil {
- if !test.shouldErr {
- t.Errorf(testPrefix+"Expected no error, found [%s]", test.expectedErrorContent, err.Error())
- }
- if !strings.Contains(err.Error(), test.expectedErrorContent) {
- t.Errorf(testPrefix+"Expected error content [%s], found [%s]", test.expectedErrorContent, err.Error())
- }
- }
-
- if err == nil && test.shouldErr {
- t.Errorf(testPrefix+"Expected error [%s] but found nil. Input file was: %s", test.expectedErrorContent, inputFilename)
- }
-
- if content != test.expectedContent {
- t.Errorf(testPrefix+"Expected content [%s] but found [%s]. Input file was: %s", test.expectedContent, content, inputFilename)
- }
- }
-}
-
-func TestIncludeNotExisting(t *testing.T) {
- context := getContextOrFail(t)
-
- _, err := context.Include("not_existing")
- if err == nil {
- t.Errorf("Expected error but found nil!")
- }
-}
-
-func TestMarkdown(t *testing.T) {
- context := getContextOrFail(t)
-
- inputFilename := "test_file"
- absInFilePath := filepath.Join(fmt.Sprintf("%s", context.Root), inputFilename)
- defer func() {
- err := os.Remove(absInFilePath)
- if err != nil && !os.IsNotExist(err) {
- t.Fatalf("Failed to clean test file!")
- }
- }()
-
- tests := []struct {
- fileContent string
- expectedContent string
- }{
- // Test 0 - test parsing of markdown
- {
- fileContent: "* str1\n* str2\n",
- expectedContent: "
\n",
- },
- }
-
- for i, test := range tests {
- testPrefix := getTestPrefix(i)
-
- // WriteFile truncates the contentt
- err := ioutil.WriteFile(absInFilePath, []byte(test.fileContent), os.ModePerm)
- if err != nil {
- t.Fatal(testPrefix+"Failed to create test file. Error was: %v", err)
- }
-
- content, _ := context.Markdown(inputFilename)
- if content != test.expectedContent {
- t.Errorf(testPrefix+"Expected content [%s] but found [%s]. Input file was: %s", test.expectedContent, content, inputFilename)
- }
- }
-}
-
-func TestCookie(t *testing.T) {
-
- tests := []struct {
- cookie *http.Cookie
- cookieName string
- expectedValue string
- }{
- // Test 0 - happy path
- {
- cookie: &http.Cookie{Name: "cookieName", Value: "cookieValue"},
- cookieName: "cookieName",
- expectedValue: "cookieValue",
- },
- // Test 1 - try to get a non-existing cookie
- {
- cookie: &http.Cookie{Name: "cookieName", Value: "cookieValue"},
- cookieName: "notExisting",
- expectedValue: "",
- },
- // Test 2 - partial name match
- {
- cookie: &http.Cookie{Name: "cookie", Value: "cookieValue"},
- cookieName: "cook",
- expectedValue: "",
- },
- // Test 3 - cookie with optional fields
- {
- cookie: &http.Cookie{Name: "cookie", Value: "cookieValue", Path: "/path", Domain: "https://localhost", Expires: (time.Now().Add(10 * time.Minute)), MaxAge: 120},
- cookieName: "cookie",
- expectedValue: "cookieValue",
- },
- }
-
- for i, test := range tests {
- testPrefix := getTestPrefix(i)
-
- // reinitialize the context for each test
- context := getContextOrFail(t)
-
- context.Req.AddCookie(test.cookie)
-
- actualCookieVal := context.Cookie(test.cookieName)
-
- if actualCookieVal != test.expectedValue {
- t.Errorf(testPrefix+"Expected cookie value [%s] but found [%s] for cookie with name %s", test.expectedValue, actualCookieVal, test.cookieName)
- }
- }
-}
-
-func TestCookieMultipleCookies(t *testing.T) {
- context := getContextOrFail(t)
-
- cookieNameBase, cookieValueBase := "cookieName", "cookieValue"
-
- // make sure that there's no state and multiple requests for different cookies return the correct result
- for i := 0; i < 10; i++ {
- context.Req.AddCookie(&http.Cookie{Name: fmt.Sprintf("%s%d", cookieNameBase, i), Value: fmt.Sprintf("%s%d", cookieValueBase, i)})
- }
-
- for i := 0; i < 10; i++ {
- expectedCookieVal := fmt.Sprintf("%s%d", cookieValueBase, i)
- actualCookieVal := context.Cookie(fmt.Sprintf("%s%d", cookieNameBase, i))
- if actualCookieVal != expectedCookieVal {
- t.Fatalf("Expected cookie value %s, found %s", expectedCookieVal, actualCookieVal)
- }
- }
-}
-
-func TestHeader(t *testing.T) {
- context := getContextOrFail(t)
-
- headerKey, headerVal := "Header1", "HeaderVal1"
- context.Req.Header.Add(headerKey, headerVal)
-
- actualHeaderVal := context.Header(headerKey)
- if actualHeaderVal != headerVal {
- t.Errorf("Expected header %s, found %s", headerVal, actualHeaderVal)
- }
-
- missingHeaderVal := context.Header("not-existing")
- if missingHeaderVal != "" {
- t.Errorf("Expected empty header value, found %s", missingHeaderVal)
- }
-}
-
-func TestIP(t *testing.T) {
- context := getContextOrFail(t)
-
- tests := []struct {
- inputRemoteAddr string
- expectedIP string
- }{
- // Test 0 - ipv4 with port
- {"1.1.1.1:1111", "1.1.1.1"},
- // Test 1 - ipv4 without port
- {"1.1.1.1", "1.1.1.1"},
- // Test 2 - ipv6 with port
- {"[::1]:11", "::1"},
- // Test 3 - ipv6 without port and brackets
- {"[2001:db8:a0b:12f0::1]", "[2001:db8:a0b:12f0::1]"},
- // Test 4 - ipv6 with zone and port
- {`[fe80:1::3%eth0]:44`, `fe80:1::3%eth0`},
- }
-
- for i, test := range tests {
- testPrefix := getTestPrefix(i)
-
- context.Req.RemoteAddr = test.inputRemoteAddr
- actualIP := context.IP()
-
- if actualIP != test.expectedIP {
- t.Errorf(testPrefix+"Expected IP %s, found %s", test.expectedIP, actualIP)
- }
- }
-}
-
-func TestURL(t *testing.T) {
- context := getContextOrFail(t)
-
- inputURL := "http://localhost"
- context.Req.RequestURI = inputURL
-
- if inputURL != context.URI() {
- t.Errorf("Expected url %s, found %s", inputURL, context.URI())
- }
-}
-
-func TestHost(t *testing.T) {
- tests := []struct {
- input string
- expectedHost string
- shouldErr bool
- }{
- {
- input: "localhost:123",
- expectedHost: "localhost",
- shouldErr: false,
- },
- {
- input: "localhost",
- expectedHost: "localhost",
- shouldErr: false,
- },
- {
- input: "[::]",
- expectedHost: "",
- shouldErr: true,
- },
- }
-
- for _, test := range tests {
- testHostOrPort(t, true, test.input, test.expectedHost, test.shouldErr)
- }
-}
-
-func TestPort(t *testing.T) {
- tests := []struct {
- input string
- expectedPort string
- shouldErr bool
- }{
- {
- input: "localhost:123",
- expectedPort: "123",
- shouldErr: false,
- },
- {
- input: "localhost",
- expectedPort: "80", // assuming 80 is the default port
- shouldErr: false,
- },
- {
- input: ":8080",
- expectedPort: "8080",
- shouldErr: false,
- },
- {
- input: "[::]",
- expectedPort: "",
- shouldErr: true,
- },
- }
-
- for _, test := range tests {
- testHostOrPort(t, false, test.input, test.expectedPort, test.shouldErr)
- }
-}
-
-func testHostOrPort(t *testing.T, isTestingHost bool, input, expectedResult string, shouldErr bool) {
- context := getContextOrFail(t)
-
- context.Req.Host = input
- var actualResult, testedObject string
- var err error
-
- if isTestingHost {
- actualResult, err = context.Host()
- testedObject = "host"
- } else {
- actualResult, err = context.Port()
- testedObject = "port"
- }
-
- if shouldErr && err == nil {
- t.Errorf("Expected error, found nil!")
- return
- }
-
- if !shouldErr && err != nil {
- t.Errorf("Expected no error, found %s", err)
- return
- }
-
- if actualResult != expectedResult {
- t.Errorf("Expected %s %s, found %s", testedObject, expectedResult, actualResult)
- }
-}
-
-func TestMethod(t *testing.T) {
- context := getContextOrFail(t)
-
- method := "POST"
- context.Req.Method = method
-
- if method != context.Method() {
- t.Errorf("Expected method %s, found %s", method, context.Method())
- }
-
-}
-
-func TestPathMatches(t *testing.T) {
- context := getContextOrFail(t)
-
- tests := []struct {
- urlStr string
- pattern string
- shouldMatch bool
- }{
- // Test 0
- {
- urlStr: "http://localhost/",
- pattern: "",
- shouldMatch: true,
- },
- // Test 1
- {
- urlStr: "http://localhost",
- pattern: "",
- shouldMatch: true,
- },
- // Test 1
- {
- urlStr: "http://localhost/",
- pattern: "/",
- shouldMatch: true,
- },
- // Test 3
- {
- urlStr: "http://localhost/?param=val",
- pattern: "/",
- shouldMatch: true,
- },
- // Test 4
- {
- urlStr: "http://localhost/dir1/dir2",
- pattern: "/dir2",
- shouldMatch: false,
- },
- // Test 5
- {
- urlStr: "http://localhost/dir1/dir2",
- pattern: "/dir1",
- shouldMatch: true,
- },
- // Test 6
- {
- urlStr: "http://localhost:444/dir1/dir2",
- pattern: "/dir1",
- shouldMatch: true,
- },
- // Test 7
- {
- urlStr: "http://localhost/dir1/dir2",
- pattern: "*/dir2",
- shouldMatch: false,
- },
- }
-
- for i, test := range tests {
- testPrefix := getTestPrefix(i)
- var err error
- context.Req.URL, err = url.Parse(test.urlStr)
- if err != nil {
- t.Fatalf("Failed to prepare test URL from string %s! Error was: %s", test.urlStr, err)
- }
-
- matches := context.PathMatches(test.pattern)
- if matches != test.shouldMatch {
- t.Errorf(testPrefix+"Expected and actual result differ: expected to match [%t], actual matches [%t]", test.shouldMatch, matches)
- }
- }
-}
-
-func TestTruncate(t *testing.T) {
- context := getContextOrFail(t)
- tests := []struct {
- inputString string
- inputLength int
- expected string
- }{
- // Test 0 - small length
- {
- inputString: "string",
- inputLength: 1,
- expected: "s",
- },
- // Test 1 - exact length
- {
- inputString: "string",
- inputLength: 6,
- expected: "string",
- },
- // Test 2 - bigger length
- {
- inputString: "string",
- inputLength: 10,
- expected: "string",
- },
- // Test 3 - zero length
- {
- inputString: "string",
- inputLength: 0,
- expected: "",
- },
- // Test 4 - negative, smaller length
- {
- inputString: "string",
- inputLength: -5,
- expected: "tring",
- },
- // Test 5 - negative, exact length
- {
- inputString: "string",
- inputLength: -6,
- expected: "string",
- },
- // Test 6 - negative, bigger length
- {
- inputString: "string",
- inputLength: -7,
- expected: "string",
- },
- }
-
- for i, test := range tests {
- actual := context.Truncate(test.inputString, test.inputLength)
- if actual != test.expected {
- t.Errorf(getTestPrefix(i)+"Expected '%s', found '%s'. Input was Truncate(%q, %d)", test.expected, actual, test.inputString, test.inputLength)
- }
- }
-}
-
-func TestStripHTML(t *testing.T) {
- context := getContextOrFail(t)
- tests := []struct {
- input string
- expected string
- }{
- // Test 0 - no tags
- {
- input: `h1`,
- expected: `h1`,
- },
- // Test 1 - happy path
- {
- input: `h1
`,
- expected: `h1`,
- },
- // Test 2 - tag in quotes
- {
- input: `">h1
`,
- expected: `h1`,
- },
- // Test 3 - multiple tags
- {
- input: `h1
`,
- expected: `h1`,
- },
- // Test 4 - tags not closed
- {
- input: `hi`,
- expected: ` 0:
- answer = context.AnswerMessage()
+ answer = state.AnswerMessage()
answer.Answer = names
default:
- answer = context.ErrorMessage(dns.RcodeServerFailure)
+ answer = state.ErrorMessage(dns.RcodeServerFailure)
}
// Check return size, etc. TODO(miek)
w.WriteMsg(answer)
diff --git a/middleware/file/file_test.go b/middleware/file/file_test.go
index 54584b5cc..00f667d57 100644
--- a/middleware/file/file_test.go
+++ b/middleware/file/file_test.go
@@ -1,15 +1,6 @@
package file
-import (
- "errors"
- "net/http"
- "net/http/httptest"
- "os"
- "path/filepath"
- "strings"
- "testing"
-)
-
+/*
var testDir = filepath.Join(os.TempDir(), "caddy_testdir")
var ErrCustom = errors.New("Custom Error")
@@ -323,3 +314,4 @@ func TestServeHTTPFailingStat(t *testing.T) {
}
}
}
+*/
diff --git a/middleware/log/log.go b/middleware/log/log.go
index 109add9f5..c0f960fed 100644
--- a/middleware/log/log.go
+++ b/middleware/log/log.go
@@ -4,6 +4,8 @@ package log
import (
"log"
+ "golang.org/x/net/context"
+
"github.com/miekg/coredns/middleware"
"github.com/miekg/dns"
)
@@ -15,7 +17,7 @@ type Logger struct {
ErrorFunc func(dns.ResponseWriter, *dns.Msg, int) // failover error handler
}
-func (l Logger) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
+func (l Logger) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
for _, rule := range l.Rules {
/*
if middleware.Path(r.URL.Path).Matches(rule.PathScope) {
@@ -40,7 +42,7 @@ func (l Logger) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
*/
rule = rule
}
- return l.Next.ServeDNS(w, r)
+ return l.Next.ServeDNS(ctx, w, r)
}
// Rule configures the logging middleware.
diff --git a/middleware/log/log_test.go b/middleware/log/log_test.go
index 40560e4c0..6d41c4926 100644
--- a/middleware/log/log_test.go
+++ b/middleware/log/log_test.go
@@ -1,17 +1,9 @@
package log
-import (
- "bytes"
- "log"
- "net/http"
- "net/http/httptest"
- "strings"
- "testing"
-)
-
+/*
type erroringMiddleware struct{}
-func (erroringMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
+func (erroringMiddleware) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
return http.StatusNotFound, nil
}
@@ -46,3 +38,4 @@ func TestLoggedStatus(t *testing.T) {
t.Error("Expected 404 to be logged. Logged string -", logged)
}
}
+*/
diff --git a/middleware/middleware.go b/middleware/middleware.go
index 436ec86e9..1ce5f62d6 100644
--- a/middleware/middleware.go
+++ b/middleware/middleware.go
@@ -5,6 +5,7 @@ import (
"time"
"github.com/miekg/dns"
+ "golang.org/x/net/context"
)
type (
@@ -32,18 +33,18 @@ type (
// Otherwise, return values should be propagated down the middleware
// chain by returning them unchanged.
Handler interface {
- ServeDNS(dns.ResponseWriter, *dns.Msg) (int, error)
+ ServeDNS(context.Context, dns.ResponseWriter, *dns.Msg) (int, error)
}
// HandlerFunc is a convenience type like dns.HandlerFunc, except
// ServeDNS returns an rcode and an error. See Handler
// documentation for more information.
- HandlerFunc func(dns.ResponseWriter, *dns.Msg) (int, error)
+ HandlerFunc func(context.Context, dns.ResponseWriter, *dns.Msg) (int, error)
)
// ServeDNS implements the Handler interface.
-func (f HandlerFunc) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
- return f(w, r)
+func (f HandlerFunc) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
+ return f(ctx, w, r)
}
// IndexFile looks for a file in /root/fpath/indexFile for each string
diff --git a/middleware/middleware_test.go b/middleware/middleware_test.go
index 62fa4e250..c870d7c16 100644
--- a/middleware/middleware_test.go
+++ b/middleware/middleware_test.go
@@ -1,108 +1 @@
package middleware
-
-import (
- "fmt"
- "net/http"
- "net/http/httptest"
- "testing"
- "time"
-)
-
-func TestIndexfile(t *testing.T) {
- tests := []struct {
- rootDir http.FileSystem
- fpath string
- indexFiles []string
- shouldErr bool
- expectedFilePath string //retun value
- expectedBoolValue bool //return value
- }{
- {
- http.Dir("./templates/testdata"),
- "/images/",
- []string{"img.htm"},
- false,
- "/images/img.htm",
- true,
- },
- }
- for i, test := range tests {
- actualFilePath, actualBoolValue := IndexFile(test.rootDir, test.fpath, test.indexFiles)
- if actualBoolValue == true && test.shouldErr {
- t.Errorf("Test %d didn't error, but it should have", i)
- } else if actualBoolValue != true && !test.shouldErr {
- t.Errorf("Test %d errored, but it shouldn't have; got %s", i, "Please Add a / at the end of fpath or the indexFiles doesnt exist")
- }
- if actualFilePath != test.expectedFilePath {
- t.Fatalf("Test %d expected returned filepath to be %s, but got %s ",
- i, test.expectedFilePath, actualFilePath)
-
- }
- if actualBoolValue != test.expectedBoolValue {
- t.Fatalf("Test %d expected returned bool value to be %v, but got %v ",
- i, test.expectedBoolValue, actualBoolValue)
-
- }
- }
-}
-
-func TestSetLastModified(t *testing.T) {
- nowTime := time.Now()
-
- // ovewrite the function to return reliable time
- originalGetCurrentTimeFunc := currentTime
- currentTime = func() time.Time {
- return nowTime
- }
- defer func() {
- currentTime = originalGetCurrentTimeFunc
- }()
-
- pastTime := nowTime.Truncate(1 * time.Hour)
- futureTime := nowTime.Add(1 * time.Hour)
-
- tests := []struct {
- inputModTime time.Time
- expectedIsHeaderSet bool
- expectedLastModified string
- }{
- {
- inputModTime: pastTime,
- expectedIsHeaderSet: true,
- expectedLastModified: pastTime.UTC().Format(http.TimeFormat),
- },
- {
- inputModTime: nowTime,
- expectedIsHeaderSet: true,
- expectedLastModified: nowTime.UTC().Format(http.TimeFormat),
- },
- {
- inputModTime: futureTime,
- expectedIsHeaderSet: true,
- expectedLastModified: nowTime.UTC().Format(http.TimeFormat),
- },
- {
- inputModTime: time.Time{},
- expectedIsHeaderSet: false,
- },
- }
-
- for i, test := range tests {
- responseRecorder := httptest.NewRecorder()
- errorPrefix := fmt.Sprintf("Test [%d]: ", i)
- SetLastModifiedHeader(responseRecorder, test.inputModTime)
- actualLastModifiedHeader := responseRecorder.Header().Get("Last-Modified")
-
- if test.expectedIsHeaderSet && actualLastModifiedHeader == "" {
- t.Fatalf(errorPrefix + "Expected to find Last-Modified header, but found nothing")
- }
-
- if !test.expectedIsHeaderSet && actualLastModifiedHeader != "" {
- t.Fatalf(errorPrefix+"Did not expect to find Last-Modified header, but found one [%s].", actualLastModifiedHeader)
- }
-
- if test.expectedLastModified != actualLastModifiedHeader {
- t.Errorf(errorPrefix+"Expected Last-Modified content [%s], found [%s}", test.expectedLastModified, actualLastModifiedHeader)
- }
- }
-}
diff --git a/middleware/prometheus/handler.go b/middleware/prometheus/handler.go
index eb82b8aff..a0cfcc872 100644
--- a/middleware/prometheus/handler.go
+++ b/middleware/prometheus/handler.go
@@ -4,15 +4,17 @@ import (
"strconv"
"time"
+ "golang.org/x/net/context"
+
"github.com/miekg/coredns/middleware"
"github.com/miekg/dns"
)
-func (m *Metrics) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
- context := middleware.Context{W: w, Req: r}
+func (m *Metrics) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
+ state := middleware.State{W: w, Req: r}
- qname := context.Name()
- qtype := context.Type()
+ qname := state.Name()
+ qtype := state.Type()
zone := middleware.Zones(m.ZoneNames).Matches(qname)
if zone == "" {
zone = "."
@@ -20,7 +22,7 @@ func (m *Metrics) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
// Record response to get status code and size of the reply.
rw := middleware.NewResponseRecorder(w)
- status, err := m.Next.ServeDNS(rw, r)
+ status, err := m.Next.ServeDNS(ctx, rw, r)
requestCount.WithLabelValues(zone, qtype).Inc()
requestDuration.WithLabelValues(zone).Observe(float64(time.Since(rw.Start()) / time.Second))
diff --git a/middleware/proxy/proxy.go b/middleware/proxy/proxy.go
index 169e41b61..dfdfd082e 100644
--- a/middleware/proxy/proxy.go
+++ b/middleware/proxy/proxy.go
@@ -7,6 +7,8 @@ import (
"sync/atomic"
"time"
+ "golang.org/x/net/context"
+
"github.com/miekg/coredns/middleware"
"github.com/miekg/dns"
)
@@ -67,7 +69,7 @@ func (uh *UpstreamHost) Down() bool {
var tryDuration = 60 * time.Second
// ServeDNS satisfies the middleware.Handler interface.
-func (p Proxy) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
+func (p Proxy) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
for _, upstream := range p.Upstreams {
// allowed bla bla bla TODO(miek): fix full proxy spec from caddy
start := time.Now()
@@ -100,7 +102,7 @@ func (p Proxy) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
}
return dns.RcodeServerFailure, errUnreachable
}
- return p.Next.ServeDNS(w, r)
+ return p.Next.ServeDNS(ctx, w, r)
}
func Clients() Client {
diff --git a/middleware/proxy/proxy_test.go b/middleware/proxy/proxy_test.go
index 8066874d2..e4b403254 100644
--- a/middleware/proxy/proxy_test.go
+++ b/middleware/proxy/proxy_test.go
@@ -1,26 +1,6 @@
package proxy
-import (
- "bufio"
- "bytes"
- "fmt"
- "io"
- "io/ioutil"
- "log"
- "net"
- "net/http"
- "net/http/httptest"
- "net/url"
- "os"
- "path/filepath"
- "runtime"
- "strings"
- "testing"
- "time"
-
- "golang.org/x/net/websocket"
-)
-
+/*
func init() {
tryDuration = 50 * time.Millisecond // prevent tests from hanging
}
@@ -315,3 +295,4 @@ func (c *fakeConn) SetWriteDeadline(t time.Time) error { return nil }
func (c *fakeConn) Close() error { return nil }
func (c *fakeConn) Read(b []byte) (int, error) { return c.readBuf.Read(b) }
func (c *fakeConn) Write(b []byte) (int, error) { return c.writeBuf.Write(b) }
+*/
diff --git a/middleware/proxy/reverseproxy.go b/middleware/proxy/reverseproxy.go
index 6d27da042..460b294a8 100644
--- a/middleware/proxy/reverseproxy.go
+++ b/middleware/proxy/reverseproxy.go
@@ -12,15 +12,15 @@ type ReverseProxy struct {
}
func (p ReverseProxy) ServeDNS(w dns.ResponseWriter, r *dns.Msg, extra []dns.RR) error {
- // TODO(miek): use extra!
+ // TODO(miek): use extra to EDNS0.
var (
reply *dns.Msg
err error
)
- context := middleware.Context{W: w, Req: r}
+ state := middleware.State{W: w, Req: r}
// tls+tcp ?
- if context.Proto() == "tcp" {
+ if state.Proto() == "tcp" {
reply, err = middleware.Exchange(p.Client.TCP, r, p.Host)
} else {
reply, err = middleware.Exchange(p.Client.UDP, r, p.Host)
diff --git a/middleware/recorder_test.go b/middleware/recorder_test.go
index a8c8a5d04..30931a0e3 100644
--- a/middleware/recorder_test.go
+++ b/middleware/recorder_test.go
@@ -1,11 +1,6 @@
package middleware
-import (
- "net/http"
- "net/http/httptest"
- "testing"
-)
-
+/*
func TestNewResponseRecorder(t *testing.T) {
w := httptest.NewRecorder()
recordRequest := NewResponseRecorder(w)
@@ -30,3 +25,4 @@ func TestWrite(t *testing.T) {
t.Fatalf("Expected Response Body to be %s , but found %s\n", responseTestString, w.Body.String())
}
}
+*/
diff --git a/middleware/reflect/reflect.go b/middleware/reflect/reflect.go
index 6d5847b81..6e49c4199 100644
--- a/middleware/reflect/reflect.go
+++ b/middleware/reflect/reflect.go
@@ -20,6 +20,8 @@ import (
"net"
"strings"
+ "golang.org/x/net/context"
+
"github.com/miekg/coredns/middleware"
"github.com/miekg/dns"
)
@@ -28,15 +30,15 @@ type Reflect struct {
Next middleware.Handler
}
-func (rl Reflect) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
- context := middleware.Context{Req: r, W: w}
+func (rl Reflect) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
+ state := middleware.State{Req: r, W: w}
class := r.Question[0].Qclass
qname := r.Question[0].Name
i, ok := dns.NextLabel(qname, 0)
if strings.ToLower(qname[:i]) != who || ok {
- err := context.ErrorMessage(dns.RcodeFormatError)
+ err := state.ErrorMessage(dns.RcodeFormatError)
w.WriteMsg(err)
return dns.RcodeFormatError, errors.New(dns.RcodeToString[dns.RcodeFormatError])
}
@@ -46,10 +48,10 @@ func (rl Reflect) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
answer.Compress = true
answer.Authoritative = true
- ip := context.IP()
- proto := context.Proto()
- port, _ := context.Port()
- family := context.Family()
+ ip := state.IP()
+ proto := state.Proto()
+ port, _ := state.Port()
+ family := state.Family()
var rr dns.RR
switch family {
@@ -67,7 +69,7 @@ func (rl Reflect) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
t.Hdr = dns.RR_Header{Name: qname, Rrtype: dns.TypeTXT, Class: class, Ttl: 0}
t.Txt = []string{"Port: " + port + " (" + proto + ")"}
- switch context.Type() {
+ switch state.Type() {
case "TXT":
answer.Answer = append(answer.Answer, t)
answer.Extra = append(answer.Extra, rr)
diff --git a/middleware/replacer.go b/middleware/replacer.go
index 133da74c5..03ebecd64 100644
--- a/middleware/replacer.go
+++ b/middleware/replacer.go
@@ -29,19 +29,19 @@ type replacer struct {
// available. emptyValue should be the string that is used
// in place of empty string (can still be empty string).
func NewReplacer(r *dns.Msg, rr *ResponseRecorder, emptyValue string) Replacer {
- context := Context{W: rr, Req: r}
+ state := State{W: rr, Req: r}
rep := replacer{
replacements: map[string]string{
- "{type}": context.Type(),
- "{name}": context.Name(),
- "{class}": context.Class(),
- "{proto}": context.Proto(),
+ "{type}": state.Type(),
+ "{name}": state.Name(),
+ "{class}": state.Class(),
+ "{proto}": state.Proto(),
"{when}": func() string {
return time.Now().Format(timeFormat)
}(),
- "{remote}": context.IP(),
+ "{remote}": state.IP(),
"{port}": func() string {
- p, _ := context.Port()
+ p, _ := state.Port()
return p
}(),
},
diff --git a/middleware/replacer_test.go b/middleware/replacer_test.go
index d98bd2de1..378e4083d 100644
--- a/middleware/replacer_test.go
+++ b/middleware/replacer_test.go
@@ -1,12 +1,6 @@
package middleware
-import (
- "net/http"
- "net/http/httptest"
- "strings"
- "testing"
-)
-
+/*
func TestNewReplacer(t *testing.T) {
w := httptest.NewRecorder()
recordRequest := NewResponseRecorder(w)
@@ -122,3 +116,4 @@ func TestSet(t *testing.T) {
t.Error("Expected variable replacement failed")
}
}
+*/
diff --git a/middleware/rewrite/condition_test.go b/middleware/rewrite/condition_test.go
index 3c3b6053a..91004f9d7 100644
--- a/middleware/rewrite/condition_test.go
+++ b/middleware/rewrite/condition_test.go
@@ -1,11 +1,6 @@
package rewrite
-import (
- "net/http"
- "strings"
- "testing"
-)
-
+/*
func TestConditions(t *testing.T) {
tests := []struct {
condition string
@@ -104,3 +99,4 @@ func TestConditions(t *testing.T) {
}
}
}
+*/
diff --git a/middleware/rewrite/rewrite.go b/middleware/rewrite/rewrite.go
index b3039615b..91b35d236 100644
--- a/middleware/rewrite/rewrite.go
+++ b/middleware/rewrite/rewrite.go
@@ -5,6 +5,7 @@ package rewrite
import (
"github.com/miekg/coredns/middleware"
"github.com/miekg/dns"
+ "golang.org/x/net/context"
)
// Result is the result of a rewrite
@@ -27,12 +28,12 @@ type Rewrite struct {
}
// ServeHTTP implements the middleware.Handler interface.
-func (rw Rewrite) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
+func (rw Rewrite) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
wr := NewResponseReverter(w, r)
for _, rule := range rw.Rules {
switch result := rule.Rewrite(r); result {
case RewriteDone:
- return rw.Next.ServeDNS(wr, r)
+ return rw.Next.ServeDNS(ctx, wr, r)
case RewriteIgnored:
break
case RewriteStatus:
@@ -42,7 +43,7 @@ func (rw Rewrite) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
// }
}
}
- return rw.Next.ServeDNS(w, r)
+ return rw.Next.ServeDNS(ctx, w, r)
}
// Rule describes an internal location rewrite rule.
diff --git a/middleware/rewrite/rewrite_test.go b/middleware/rewrite/rewrite_test.go
index f57dfd602..b6b01fc94 100644
--- a/middleware/rewrite/rewrite_test.go
+++ b/middleware/rewrite/rewrite_test.go
@@ -1,15 +1,6 @@
package rewrite
-import (
- "fmt"
- "net/http"
- "net/http/httptest"
- "strings"
- "testing"
-
- "github.com/miekg/coredns/middleware"
-)
-
+/*
func TestRewrite(t *testing.T) {
rw := Rewrite{
Next: middleware.HandlerFunc(urlPrinter),
@@ -157,3 +148,4 @@ func urlPrinter(w http.ResponseWriter, r *http.Request) (int, error) {
fmt.Fprintf(w, r.URL.String())
return 0, nil
}
+*/
diff --git a/middleware/context.go b/middleware/state.go
similarity index 53%
rename from middleware/context.go
rename to middleware/state.go
index 8868c1c03..163a0dae2 100644
--- a/middleware/context.go
+++ b/middleware/state.go
@@ -9,45 +9,44 @@ import (
"github.com/miekg/dns"
)
-// This file contains the context and functions available for
-// use in the templates.
+// This file contains the state nd functions available for use in the templates.
-// Context is the context with which Caddy templates are executed.
-type Context struct {
- Root http.FileSystem // TODO(miek): needed
+// State contains some connection state and is useful in middleware.
+type State struct {
+ Root http.FileSystem // TODO(miek): needed?
Req *dns.Msg
W dns.ResponseWriter
}
// Now returns the current timestamp in the specified format.
-func (c Context) Now(format string) string {
+func (s State) Now(format string) string {
return time.Now().Format(format)
}
// NowDate returns the current date/time that can be used
// in other time functions.
-func (c Context) NowDate() time.Time {
+func (s State) NowDate() time.Time {
return time.Now()
}
// Header gets the value of a header.
-func (c Context) Header() *dns.RR_Header {
+func (s State) Header() *dns.RR_Header {
// TODO(miek)
return nil
}
// IP gets the (remote) IP address of the client making the request.
-func (c Context) IP() string {
- ip, _, err := net.SplitHostPort(c.W.RemoteAddr().String())
+func (s State) IP() string {
+ ip, _, err := net.SplitHostPort(s.W.RemoteAddr().String())
if err != nil {
- return c.W.RemoteAddr().String()
+ return s.W.RemoteAddr().String()
}
return ip
}
// Post gets the (remote) Port of the client making the request.
-func (c Context) Port() (string, error) {
- _, port, err := net.SplitHostPort(c.W.RemoteAddr().String())
+func (s State) Port() (string, error) {
+ _, port, err := net.SplitHostPort(s.W.RemoteAddr().String())
if err != nil {
return "0", err
}
@@ -56,11 +55,11 @@ func (c Context) Port() (string, error) {
// Proto gets the protocol used as the transport. This
// will be udp or tcp.
-func (c Context) Proto() string {
- if _, ok := c.W.RemoteAddr().(*net.UDPAddr); ok {
+func (s State) Proto() string {
+ if _, ok := s.W.RemoteAddr().(*net.UDPAddr); ok {
return "udp"
}
- if _, ok := c.W.RemoteAddr().(*net.TCPAddr); ok {
+ if _, ok := s.W.RemoteAddr().(*net.TCPAddr); ok {
return "tcp"
}
return "udp"
@@ -68,9 +67,9 @@ func (c Context) Proto() string {
// Family returns the family of the transport.
// 1 for IPv4 and 2 for IPv6.
-func (c Context) Family() int {
+func (s State) Family() int {
var a net.IP
- ip := c.W.RemoteAddr()
+ ip := s.W.RemoteAddr()
if i, ok := ip.(*net.UDPAddr); ok {
a = i.IP
}
@@ -85,51 +84,48 @@ func (c Context) Family() int {
}
// Type returns the type of the question as a string.
-func (c Context) Type() string {
- return dns.Type(c.Req.Question[0].Qtype).String()
+func (s State) Type() string {
+ return dns.Type(s.Req.Question[0].Qtype).String()
}
// QType returns the type of the question as a uint16.
-func (c Context) QType() uint16 {
- return c.Req.Question[0].Qtype
+func (s State) QType() uint16 {
+ return s.Req.Question[0].Qtype
}
// Name returns the name of the question in the request. Note
// this name will always have a closing dot and will be lower cased.
-func (c Context) Name() string {
- return strings.ToLower(dns.Name(c.Req.Question[0].Name).String())
+func (s State) Name() string {
+ return strings.ToLower(dns.Name(s.Req.Question[0].Name).String())
}
// QName returns the name of the question in the request.
-func (c Context) QName() string {
- return dns.Name(c.Req.Question[0].Name).String()
+func (s State) QName() string {
+ return dns.Name(s.Req.Question[0].Name).String()
}
// Class returns the class of the question in the request.
-func (c Context) Class() string {
- return dns.Class(c.Req.Question[0].Qclass).String()
+func (s State) Class() string {
+ return dns.Class(s.Req.Question[0].Qclass).String()
}
// QClass returns the class of the question in the request.
-func (c Context) QClass() uint16 {
- return c.Req.Question[0].Qclass
+func (s State) QClass() uint16 {
+ return s.Req.Question[0].Qclass
}
-// More convience types for extracting stuff from a message?
-// Header?
-
// ErrorMessage returns an error message suitable for sending
// back to the client.
-func (c Context) ErrorMessage(rcode int) *dns.Msg {
+func (s State) ErrorMessage(rcode int) *dns.Msg {
m := new(dns.Msg)
- m.SetRcode(c.Req, rcode)
+ m.SetRcode(s.Req, rcode)
return m
}
// AnswerMessage returns an error message suitable for sending
// back to the client.
-func (c Context) AnswerMessage() *dns.Msg {
+func (s State) AnswerMessage() *dns.Msg {
m := new(dns.Msg)
- m.SetReply(c.Req)
+ m.SetReply(s.Req)
return m
}
diff --git a/middleware/state_test.go b/middleware/state_test.go
new file mode 100644
index 000000000..462f43676
--- /dev/null
+++ b/middleware/state_test.go
@@ -0,0 +1,235 @@
+package middleware
+
+/*
+func TestHeader(t *testing.T) {
+ state := getContextOrFail(t)
+
+ headerKey, headerVal := "Header1", "HeaderVal1"
+ state.Req.Header.Add(headerKey, headerVal)
+
+ actualHeaderVal := state.Header(headerKey)
+ if actualHeaderVal != headerVal {
+ t.Errorf("Expected header %s, found %s", headerVal, actualHeaderVal)
+ }
+
+ missingHeaderVal := state.Header("not-existing")
+ if missingHeaderVal != "" {
+ t.Errorf("Expected empty header value, found %s", missingHeaderVal)
+ }
+}
+
+func TestIP(t *testing.T) {
+ state := getContextOrFail(t)
+
+ tests := []struct {
+ inputRemoteAddr string
+ expectedIP string
+ }{
+ // Test 0 - ipv4 with port
+ {"1.1.1.1:1111", "1.1.1.1"},
+ // Test 1 - ipv4 without port
+ {"1.1.1.1", "1.1.1.1"},
+ // Test 2 - ipv6 with port
+ {"[::1]:11", "::1"},
+ // Test 3 - ipv6 without port and brackets
+ {"[2001:db8:a0b:12f0::1]", "[2001:db8:a0b:12f0::1]"},
+ // Test 4 - ipv6 with zone and port
+ {`[fe80:1::3%eth0]:44`, `fe80:1::3%eth0`},
+ }
+
+ for i, test := range tests {
+ testPrefix := getTestPrefix(i)
+
+ state.Req.RemoteAddr = test.inputRemoteAddr
+ actualIP := state.IP()
+
+ if actualIP != test.expectedIP {
+ t.Errorf(testPrefix+"Expected IP %s, found %s", test.expectedIP, actualIP)
+ }
+ }
+}
+
+func TestURL(t *testing.T) {
+ state := getContextOrFail(t)
+
+ inputURL := "http://localhost"
+ state.Req.RequestURI = inputURL
+
+ if inputURL != state.URI() {
+ t.Errorf("Expected url %s, found %s", inputURL, state.URI())
+ }
+}
+
+func TestHost(t *testing.T) {
+ tests := []struct {
+ input string
+ expectedHost string
+ shouldErr bool
+ }{
+ {
+ input: "localhost:123",
+ expectedHost: "localhost",
+ shouldErr: false,
+ },
+ {
+ input: "localhost",
+ expectedHost: "localhost",
+ shouldErr: false,
+ },
+ {
+ input: "[::]",
+ expectedHost: "",
+ shouldErr: true,
+ },
+ }
+
+ for _, test := range tests {
+ testHostOrPort(t, true, test.input, test.expectedHost, test.shouldErr)
+ }
+}
+
+func TestPort(t *testing.T) {
+ tests := []struct {
+ input string
+ expectedPort string
+ shouldErr bool
+ }{
+ {
+ input: "localhost:123",
+ expectedPort: "123",
+ shouldErr: false,
+ },
+ {
+ input: "localhost",
+ expectedPort: "80", // assuming 80 is the default port
+ shouldErr: false,
+ },
+ {
+ input: ":8080",
+ expectedPort: "8080",
+ shouldErr: false,
+ },
+ {
+ input: "[::]",
+ expectedPort: "",
+ shouldErr: true,
+ },
+ }
+
+ for _, test := range tests {
+ testHostOrPort(t, false, test.input, test.expectedPort, test.shouldErr)
+ }
+}
+
+func testHostOrPort(t *testing.T, isTestingHost bool, input, expectedResult string, shouldErr bool) {
+ state := getContextOrFail(t)
+
+ state.Req.Host = input
+ var actualResult, testedObject string
+ var err error
+
+ if isTestingHost {
+ actualResult, err = state.Host()
+ testedObject = "host"
+ } else {
+ actualResult, err = state.Port()
+ testedObject = "port"
+ }
+
+ if shouldErr && err == nil {
+ t.Errorf("Expected error, found nil!")
+ return
+ }
+
+ if !shouldErr && err != nil {
+ t.Errorf("Expected no error, found %s", err)
+ return
+ }
+
+ if actualResult != expectedResult {
+ t.Errorf("Expected %s %s, found %s", testedObject, expectedResult, actualResult)
+ }
+}
+
+func TestPathMatches(t *testing.T) {
+ state := getContextOrFail(t)
+
+ tests := []struct {
+ urlStr string
+ pattern string
+ shouldMatch bool
+ }{
+ // Test 0
+ {
+ urlStr: "http://localhost/",
+ pattern: "",
+ shouldMatch: true,
+ },
+ // Test 1
+ {
+ urlStr: "http://localhost",
+ pattern: "",
+ shouldMatch: true,
+ },
+ // Test 1
+ {
+ urlStr: "http://localhost/",
+ pattern: "/",
+ shouldMatch: true,
+ },
+ // Test 3
+ {
+ urlStr: "http://localhost/?param=val",
+ pattern: "/",
+ shouldMatch: true,
+ },
+ // Test 4
+ {
+ urlStr: "http://localhost/dir1/dir2",
+ pattern: "/dir2",
+ shouldMatch: false,
+ },
+ // Test 5
+ {
+ urlStr: "http://localhost/dir1/dir2",
+ pattern: "/dir1",
+ shouldMatch: true,
+ },
+ // Test 6
+ {
+ urlStr: "http://localhost:444/dir1/dir2",
+ pattern: "/dir1",
+ shouldMatch: true,
+ },
+ }
+
+ for i, test := range tests {
+ testPrefix := getTestPrefix(i)
+ var err error
+ state.Req.URL, err = url.Parse(test.urlStr)
+ if err != nil {
+ t.Fatalf("Failed to prepare test URL from string %s! Error was: %s", test.urlStr, err)
+ }
+
+ matches := state.PathMatches(test.pattern)
+ if matches != test.shouldMatch {
+ t.Errorf(testPrefix+"Expected and actual result differ: expected to match [%t], actual matches [%t]", test.shouldMatch, matches)
+ }
+ }
+}
+
+func initTestContext() (Context, error) {
+ body := bytes.NewBufferString("request body")
+ request, err := http.NewRequest("GET", "https://localhost", body)
+ if err != nil {
+ return Context{}, err
+ }
+
+ return Context{Root: http.Dir(os.TempDir()), Req: request}, nil
+}
+
+
+func getTestPrefix(testN int) string {
+ return fmt.Sprintf("Test [%d]: ", testN)
+}
+*/
diff --git a/server/server.go b/server/server.go
index 7baa74686..0db2614bb 100644
--- a/server/server.go
+++ b/server/server.go
@@ -15,6 +15,8 @@ import (
"sync"
"time"
+ "golang.org/x/net/context"
+
"github.com/miekg/dns"
)
@@ -285,6 +287,7 @@ func (s *Server) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
q := r.Question[0].Name
b := make([]byte, len(q))
off, end := 0, false
+ ctx := context.Background()
for {
l := len(q[off:])
for i := 0; i < l; i++ {
@@ -297,7 +300,7 @@ func (s *Server) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if h, ok := s.zones[string(b[:l])]; ok {
if r.Question[0].Qtype != dns.TypeDS {
- rcode, _ := h.stack.ServeDNS(w, r)
+ rcode, _ := h.stack.ServeDNS(ctx, w, r)
if rcode > 0 {
DefaultErrorFunc(w, r, rcode)
}
@@ -311,7 +314,7 @@ func (s *Server) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
}
// Wildcard match, if we have found nothing try the root zone as a last resort.
if h, ok := s.zones["."]; ok {
- rcode, _ := h.stack.ServeDNS(w, r)
+ rcode, _ := h.stack.ServeDNS(ctx, w, r)
if rcode > 0 {
DefaultErrorFunc(w, r, rcode)
}