diff --git a/.travis.yml b/.travis.yml
index bcd0ff760..d10764882 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -12,7 +12,7 @@ go:
# In the Travis VM-based build environment, IPv6 networking is not
# enabled by default. The sysctl operations below enable IPv6.
# IPv6 is needed by some of the CoreDNS test cases. The VM environment
-# is needed to have access to sudo in the test environment. Sudo is
+# is needed to have access to sudo in the test environment. Sudo is
# needed to have docker in the test environment. Docker is needed to
# launch a kubernetes instance in the test environment.
# (Dependencies are fun! :) )
@@ -38,9 +38,9 @@ before_script:
- if which docker &>/dev/null ; then docker pull gcr.io/google_containers/hyperkube-amd64:v1.2.4 ; docker ps -a ; fi
- if which docker &>/dev/null ; then ./contrib/kubernetes/testscripts/start_k8s_with_services.sh ; docker ps -a ; fi
# Get golang dependencies, and build coredns binary
- - go get -v -d
+ - go get -v -d ./...
- go get github.com/coreos/go-etcd/etcd
- - go build -v -ldflags="-s -w"
+ #- go build -v -ldflags="-s -w"
script:
- go test -tags etcd -race -bench=. ./...
diff --git a/Makefile b/Makefile
index d8cf42b45..d455840a6 100644
--- a/Makefile
+++ b/Makefile
@@ -1,39 +1,63 @@
#BUILD_VERBOSE :=
BUILD_VERBOSE := -v
-TEST_VERBOSE :=
+#TEST_VERBOSE :=
TEST_VERBOSE := -v
DOCKER_IMAGE_NAME := $$USER/coredns
+all: coredns
-all:
+# Phony this to ensure we always build the binary.
+# TODO: Add .go file dependencies.
+.PHONY: coredns
+coredns: generate deps
go build $(BUILD_VERBOSE) -ldflags="-s -w"
.PHONY: docker
-docker: all
- GOOS=linux go build -a -tags netgo -installsuffix netgo -ldflags="-s -w"
+docker: deps
+ CGO_ENABLED=0 GOOS=linux go build -ldflags="-s -w"
docker build -t $(DOCKER_IMAGE_NAME) .
+../../mholt/caddy:
+ # Get caddy so we can generate into that codebase
+ # before getting all other dependencies.
+ go get ${BUILD_VERBOSE} github.com/mholt/caddy
+
+.PHONY: generate
+generate: ../../mholt/caddy
+ go generate $(BUILD_VERSOSE)
+
.PHONY: deps
-deps:
+deps: generate
go get ${BUILD_VERBOSE}
.PHONY: test
-test:
+test: deps
go test $(TEST_VERBOSE) ./...
.PHONY: testk8s
-testk8s:
+testk8s: deps
# With -args --v=100 the k8s API response data will be printed in the log:
#go test $(TEST_VERBOSE) -tags=k8s -run 'TestK8sIntegration' ./test -args --v=100
# Without the k8s API response data:
go test $(TEST_VERBOSE) -tags=k8s -run 'TestK8sIntegration' ./test
.PHONY: testk8s-setup
-testk8s-setup:
- go test -v ./core/setup -run TestKubernetes
+testk8s-setup: deps
+ go test -v ./middleware/kubernetes/... -run TestKubernetes
.PHONY: clean
clean:
go clean
+ rm -f coredns
+
+.PHONY: distclean
+distclean: clean
+ # Clean all dependencies and build artifacts
+ find $(GOPATH)/pkg -maxdepth 1 -mindepth 1 | xargs rm -rf
+ find $(GOPATH)/bin -maxdepth 1 -mindepth 1 | xargs rm -rf
+
+ find $(GOPATH)/src -maxdepth 1 -mindepth 1 | grep -v github | xargs rm -rf
+ find $(GOPATH)/src -maxdepth 2 -mindepth 2 | grep -v miekg | xargs rm -rf
+ find $(GOPATH)/src/github.com/miekg -maxdepth 1 -mindepth 1 \! -name \*coredns\* | xargs rm -rf
diff --git a/README.md b/README.md
index e887e0875..47824aeb2 100644
--- a/README.md
+++ b/README.md
@@ -1,7 +1,8 @@
# CoreDNS
CoreDNS is DNS server that started as a fork of [Caddy](https://github.com/mholt/caddy/). It has the
-same model: it chains middleware.
+same model: it chains middleware. In fact to similar that CoreDNS is now a server type plugin for
+CAddy, i.e. you'll need Caddy to compile CoreDNS.
CoreDNS is the successor of [SkyDNS](https://github.com/skynetservices/skydns). SkyDNS is a thin
layer that exposes services in etcd in the DNS. CoreDNS builds on this idea and is a generic DNS
@@ -38,7 +39,7 @@ There are still few [issues](https://github.com/miekg/coredns/issues), and work
things fast and reduce the memory usage.
All in all, CoreDNS should be able to provide you with enough functionality to replace parts of
-BIND9, Knot, NSD or PowerDNS.
+BIND9, Knot, NSD or PowerDNS and SkyDNS.
Most documentation is in the source and some blog articles can be [found
here](https://miek.nl/tags/coredns/). If you do want to use CoreDNS in production, please let us
know and how we can help.
@@ -46,6 +47,25 @@ know and how we can help.
is also full of examples on how to structure a Corefile (renamed from
Caddyfile when I forked it).
+## Compilation
+
+CoreDNS (as a servertype plugin for Caddy) has a hard dependency on Caddy - this is *almost* like
+the normal Go dependencies, but with a small twist, caddy (the source) need to know that CoreDNS
+exists and for this we need to add 1 line `_ "github.com/miekg/coredns/core"` to file in caddy.
+
+You have the source of CoreDNS, this should preferably be downloaded under your `$GOPATH`. Get all
+dependencies:
+
+ go get ./...
+
+Then, execute `go generate`, this will patch Caddy to add CoreDNS, and then `go build` as you would
+normally do:
+
+ go generate
+ go build
+
+Should yield a `coredns` binary.
+
## Examples
Start a simple proxy:
@@ -95,15 +115,17 @@ All the above examples are possible with the *current* CoreDNS.
## What remains to be done
-* Website?
-* Logo?
* Optimizations.
* Load testing.
* The [issues](https://github.com/miekg/coredns/issues).
-## Blog
+## Blog and Contact
+
+Website:
+Twitter: `@coredns.io`
+Docs:
+Github:
-
## Systemd service file
@@ -123,7 +145,7 @@ LimitNOFILE=8192
User=coredns
WorkingDirectory=/home/coredns
ExecStartPre=/sbin/setcap cap_net_bind_service=+ep /opt/bin/coredns
-ExecStart=/opt/bin/coredns -pidfile /home/coredns/coredns.pid -conf=/etc/coredns
+ExecStart=/opt/bin/coredns -pidfile /home/coredns/coredns.pid -conf=/etc/coredns/Corefile
ExecReload=/bin/kill -SIGUSR1 $MAINPID
Restart=on-failure
diff --git a/core/assets/path.go b/core/assets/path.go
deleted file mode 100644
index abe4638d0..000000000
--- a/core/assets/path.go
+++ /dev/null
@@ -1,29 +0,0 @@
-package assets
-
-import (
- "os"
- "path/filepath"
- "runtime"
-)
-
-// Path returns the path to the folder
-// where the application may store data. This
-// currently resolves to ~/.coredns
-func Path() string {
- return filepath.Join(userHomeDir(), ".coredns")
-}
-
-// userHomeDir returns the user's home directory according to
-// environment variables.
-//
-// Credit: http://stackoverflow.com/a/7922977/1048862
-func userHomeDir() string {
- if runtime.GOOS == "windows" {
- home := os.Getenv("HOMEDRIVE") + os.Getenv("HOMEPATH")
- if home == "" {
- home = os.Getenv("USERPROFILE")
- }
- return home
- }
- return os.Getenv("HOME")
-}
diff --git a/core/assets/path_test.go b/core/assets/path_test.go
deleted file mode 100644
index 6f2c0bfb7..000000000
--- a/core/assets/path_test.go
+++ /dev/null
@@ -1,12 +0,0 @@
-package assets
-
-import (
- "strings"
- "testing"
-)
-
-func TestPath(t *testing.T) {
- if actual := Path(); !strings.HasSuffix(actual, ".coredns") {
- t.Errorf("Expected path to be a .coredns folder, got: %v", actual)
- }
-}
diff --git a/core/caddy_test.go b/core/caddy_test.go
deleted file mode 100644
index 89ec3d045..000000000
--- a/core/caddy_test.go
+++ /dev/null
@@ -1,28 +0,0 @@
-package core
-
-/*
-func TestCaddyStartStop(t *testing.T) {
- corefile := "localhost:1984"
-
- for i := 0; i < 2; i++ {
- err := Start(CorefileInput{Contents: []byte(corefile)})
- if err != nil {
- t.Fatalf("Error starting, iteration %d: %v", i, err)
- }
-
- client := http.Client{
- Timeout: time.Duration(2 * time.Second),
- }
- resp, err := client.Get("http://localhost:1984")
- if err != nil {
- t.Fatalf("Expected GET request to succeed (iteration %d), but it failed: %v", i, err)
- }
- resp.Body.Close()
-
- err = Stop()
- if err != nil {
- t.Fatalf("Error stopping, iteration %d: %v", i, err)
- }
- }
-}
-*/
diff --git a/core/config.go b/core/config.go
deleted file mode 100644
index eea099a70..000000000
--- a/core/config.go
+++ /dev/null
@@ -1,340 +0,0 @@
-package core
-
-import (
- "bytes"
- "fmt"
- "io"
- "log"
- "net"
- "sync"
-
- "github.com/miekg/coredns/core/parse"
- "github.com/miekg/coredns/core/setup"
- "github.com/miekg/coredns/server"
-)
-
-const (
- // DefaultConfigFile is the name of the configuration file that is loaded
- // by default if no other file is specified.
- DefaultConfigFile = "Corefile"
-)
-
-func loadConfigsUpToIncludingTLS(filename string, input io.Reader) ([]server.Config, []parse.ServerBlock, int, error) {
- var configs []server.Config
-
- // Each server block represents similar hosts/addresses, since they
- // were grouped together in the Corefile.
- serverBlocks, err := parse.ServerBlocks(filename, input, true)
- if err != nil {
- return nil, nil, 0, err
- }
- if len(serverBlocks) == 0 {
- newInput := DefaultInput()
- serverBlocks, err = parse.ServerBlocks(newInput.Path(), bytes.NewReader(newInput.Body()), true)
- if err != nil {
- return nil, nil, 0, err
- }
- }
-
- var lastDirectiveIndex int // we set up directives in two parts; this stores where we left off
-
- // Iterate each server block and make a config for each one,
- // executing the directives that were parsed in order up to the tls
- // directive; this is because we must activate Let's Encrypt.
- for i, sb := range serverBlocks {
- onces := makeOnces()
- storages := makeStorages()
-
- for j, addr := range sb.Addresses {
- config := server.Config{
- Host: addr.Host,
- Port: addr.Port,
- Root: Root,
- ConfigFile: filename,
- AppName: AppName,
- AppVersion: AppVersion,
- }
-
- // It is crucial that directives are executed in the proper order.
- for k, dir := range directiveOrder {
- // Execute directive if it is in the server block
- if tokens, ok := sb.Tokens[dir.name]; ok {
- // Each setup function gets a controller, from which setup functions
- // get access to the config, tokens, and other state information useful
- // to set up its own host only.
- controller := &setup.Controller{
- Config: &config,
- Dispenser: parse.NewDispenserTokens(filename, tokens),
- OncePerServerBlock: func(f func() error) error {
- var err error
- onces[dir.name].Do(func() {
- err = f()
- })
- return err
- },
- ServerBlockIndex: i,
- ServerBlockHostIndex: j,
- ServerBlockHosts: sb.HostList(),
- ServerBlockStorage: storages[dir.name],
- }
- // execute setup function and append middleware handler, if any
- midware, err := dir.setup(controller)
- if err != nil {
- return nil, nil, lastDirectiveIndex, err
- }
- if midware != nil {
- config.Middleware = append(config.Middleware, midware)
- }
- storages[dir.name] = controller.ServerBlockStorage // persist for this server block
- }
-
- // Stop after TLS setup, since we need to activate Let's Encrypt before continuing;
- // it makes some changes to the configs that middlewares might want to know about.
- if dir.name == "tls" {
- lastDirectiveIndex = k
- break
- }
- }
-
- configs = append(configs, config)
- }
- }
- return configs, serverBlocks, lastDirectiveIndex, nil
-}
-
-// loadConfigs reads input (named filename) and parses it, returning the
-// server configurations in the order they appeared in the input. As part
-// of this, it activates Let's Encrypt for the configs that are produced.
-// Thus, the returned configs are already optimally configured for HTTPS.
-func loadConfigs(filename string, input io.Reader) ([]server.Config, error) {
- configs, serverBlocks, lastDirectiveIndex, err := loadConfigsUpToIncludingTLS(filename, input)
- if err != nil {
- return nil, err
- }
-
- // Now we have all the configs, but they have only been set up to the
- // point of tls. We need to activate Let's Encrypt before setting up
- // the rest of the middlewares so they have correct information regarding
- // TLS configuration, if necessary. (this only appends, so our iterations
- // over server blocks below shouldn't be affected)
- if !IsRestart() && !Quiet {
- fmt.Println("Activating privacy features...")
- }
- /* TODO(miek): stopped for now
- configs, err = https.Activate(configs)
- if err != nil {
- return nil, err
- } else if !IsRestart() && !Quiet {
- fmt.Println(" done.")
- }
- */
-
- // Finish setting up the rest of the directives, now that TLS is
- // optimally configured. These loops are similar to above except
- // we don't iterate all the directives from the beginning and we
- // don't create new configs.
- configIndex := -1
- for i, sb := range serverBlocks {
- onces := makeOnces()
- storages := makeStorages()
-
- for j := range sb.Addresses {
- configIndex++
-
- for k := lastDirectiveIndex + 1; k < len(directiveOrder); k++ {
- dir := directiveOrder[k]
-
- if tokens, ok := sb.Tokens[dir.name]; ok {
- controller := &setup.Controller{
- Config: &configs[configIndex],
- Dispenser: parse.NewDispenserTokens(filename, tokens),
- OncePerServerBlock: func(f func() error) error {
- var err error
- onces[dir.name].Do(func() {
- err = f()
- })
- return err
- },
- ServerBlockIndex: i,
- ServerBlockHostIndex: j,
- ServerBlockHosts: sb.HostList(),
- ServerBlockStorage: storages[dir.name],
- }
- midware, err := dir.setup(controller)
- if err != nil {
- return nil, err
- }
- if midware != nil {
- configs[configIndex].Middleware = append(configs[configIndex].Middleware, midware)
- }
- storages[dir.name] = controller.ServerBlockStorage // persist for this server block
- }
- }
- }
- }
-
- return configs, nil
-}
-
-// makeOnces makes a map of directive name to sync.Once
-// instance. This is intended to be called once per server
-// block when setting up configs so that Setup functions
-// for each directive can perform a task just once per
-// server block, even if there are multiple hosts on the block.
-//
-// We need one Once per directive, otherwise the first
-// directive to use it would exclude other directives from
-// using it at all, which would be a bug.
-func makeOnces() map[string]*sync.Once {
- onces := make(map[string]*sync.Once)
- for _, dir := range directiveOrder {
- onces[dir.name] = new(sync.Once)
- }
- return onces
-}
-
-// makeStorages makes a map of directive name to interface{}
-// so that directives' setup functions can persist state
-// between different hosts on the same server block during the
-// setup phase.
-func makeStorages() map[string]interface{} {
- storages := make(map[string]interface{})
- for _, dir := range directiveOrder {
- storages[dir.name] = nil
- }
- return storages
-}
-
-// arrangeBindings groups configurations by their bind address. For example,
-// a server that should listen on localhost and another on 127.0.0.1 will
-// be grouped into the same address: 127.0.0.1. It will return an error
-// if an address is malformed or a TLS listener is configured on the
-// same address as a plaintext HTTP listener. The return value is a map of
-// bind address to list of configs that would become VirtualHosts on that
-// server. Use the keys of the returned map to create listeners, and use
-// the associated values to set up the virtualhosts.
-func arrangeBindings(allConfigs []server.Config) (bindingGroup, error) {
- var groupings bindingGroup
-
- // Group configs by bind address
- for _, conf := range allConfigs {
- // use default port if none is specified
- if conf.Port == "" {
- conf.Port = Port
- }
-
- bindAddr, warnErr, fatalErr := resolveAddr(conf)
- if fatalErr != nil {
- return groupings, fatalErr
- }
- if warnErr != nil {
- log.Printf("[WARNING] Resolving bind address for %s: %v", conf.Address(), warnErr)
- }
-
- // Make sure to compare the string representation of the address,
- // not the pointer, since a new *TCPAddr is created each time.
- var existing bool
- for i := 0; i < len(groupings); i++ {
- if groupings[i].BindAddr.String() == bindAddr.String() {
- groupings[i].Configs = append(groupings[i].Configs, conf)
- existing = true
- break
- }
- }
- if !existing {
- groupings = append(groupings, bindingMapping{
- BindAddr: bindAddr,
- Configs: []server.Config{conf},
- })
- }
- }
-
- // Don't allow HTTP and HTTPS to be served on the same address
- for _, group := range groupings {
- isTLS := group.Configs[0].TLS.Enabled
- for _, config := range group.Configs {
- if config.TLS.Enabled != isTLS {
- thisConfigProto, otherConfigProto := "HTTP", "HTTP"
- if config.TLS.Enabled {
- thisConfigProto = "HTTPS"
- }
- if group.Configs[0].TLS.Enabled {
- otherConfigProto = "HTTPS"
- }
- return groupings, fmt.Errorf("configuration error: Cannot multiplex %s (%s) and %s (%s) on same address",
- group.Configs[0].Address(), otherConfigProto, config.Address(), thisConfigProto)
- }
- }
- }
-
- return groupings, nil
-}
-
-// resolveAddr determines the address (host and port) that a config will
-// bind to. The returned address, resolvAddr, should be used to bind the
-// listener or group the config with other configs using the same address.
-// The first error, if not nil, is just a warning and should be reported
-// but execution may continue. The second error, if not nil, is a real
-// problem and the server should not be started.
-//
-// This function does not handle edge cases like port "http" or "https" if
-// they are not known to the system. It does, however, serve on the wildcard
-// host if resolving the address of the specific hostname fails.
-func resolveAddr(conf server.Config) (resolvAddr *net.TCPAddr, warnErr, fatalErr error) {
- resolvAddr, warnErr = net.ResolveTCPAddr("tcp", net.JoinHostPort(conf.BindHost, conf.Port))
- if warnErr != nil {
- // the hostname probably couldn't be resolved, just bind to wildcard then
- resolvAddr, fatalErr = net.ResolveTCPAddr("tcp", net.JoinHostPort("", conf.Port))
- if fatalErr != nil {
- return
- }
- }
-
- return
-}
-
-// validDirective returns true if d is a valid
-// directive; false otherwise.
-func validDirective(d string) bool {
- for _, dir := range directiveOrder {
- if dir.name == d {
- return true
- }
- }
- return false
-}
-
-// DefaultInput returns the default Corefile input
-// to use when it is otherwise empty or missing.
-// It uses the default host and port and root.
-func DefaultInput() CorefileInput {
- port := Port
- return CorefileInput{
- Contents: []byte(fmt.Sprintf("%s:%s\nroot %s", Host, port, Root)),
- }
-}
-
-// These defaults are configurable through the command line
-var (
- // Root is the site root
- Root = DefaultRoot
-
- // Host is the site host
- Host = DefaultHost
-
- // Port is the site port
- Port = DefaultPort
-)
-
-// bindingMapping maps a network address to configurations
-// that will bind to it. The order of the configs is important.
-type bindingMapping struct {
- BindAddr *net.TCPAddr
- Configs []server.Config
-}
-
-// bindingGroup maps network addresses to their configurations.
-// Preserving the order of the groupings is important
-// (related to graceful shutdown and restart)
-// so this is a slice, not a literal map.
-type bindingGroup []bindingMapping
diff --git a/core/config_test.go b/core/config_test.go
deleted file mode 100644
index c28dd4fcf..000000000
--- a/core/config_test.go
+++ /dev/null
@@ -1,159 +0,0 @@
-package core
-
-import (
- "reflect"
- "sync"
- "testing"
-
- "github.com/miekg/coredns/server"
-)
-
-func TestDefaultInput(t *testing.T) {
- if actual, expected := string(DefaultInput().Body()), ":53\nroot ."; actual != expected {
- t.Errorf("Host=%s; Port=%s; Root=%s;\nEXPECTED: '%s'\n ACTUAL: '%s'", Host, Port, Root, expected, actual)
- }
-
- // next few tests simulate user providing -host and/or -port flags
-
- Host = "not-localhost.com"
- if actual, expected := string(DefaultInput().Body()), "not-localhost.com:53\nroot ."; actual != expected {
- t.Errorf("Host=%s; Port=%s; Root=%s;\nEXPECTED: '%s'\n ACTUAL: '%s'", Host, Port, Root, expected, actual)
- }
-
- Host = "[::1]"
- if actual, expected := string(DefaultInput().Body()), "[::1]:53\nroot ."; actual != expected {
- t.Errorf("Host=%s; Port=%s; Root=%s;\nEXPECTED: '%s'\n ACTUAL: '%s'", Host, Port, Root, expected, actual)
- }
-
- Host = "127.0.1.1"
- if actual, expected := string(DefaultInput().Body()), "127.0.1.1:53\nroot ."; actual != expected {
- t.Errorf("Host=%s; Port=%s; Root=%s;\nEXPECTED: '%s'\n ACTUAL: '%s'", Host, Port, Root, expected, actual)
- }
-
- Host = "not-localhost.com"
- Port = "1234"
- if actual, expected := string(DefaultInput().Body()), "not-localhost.com:1234\nroot ."; actual != expected {
- t.Errorf("Host=%s; Port=%s; Root=%s;\nEXPECTED: '%s'\n ACTUAL: '%s'", Host, Port, Root, expected, actual)
- }
-
- Host = DefaultHost
- Port = "1234"
- if actual, expected := string(DefaultInput().Body()), ":1234\nroot ."; actual != expected {
- t.Errorf("Host=%s; Port=%s; Root=%s;\nEXPECTED: '%s'\n ACTUAL: '%s'", Host, Port, Root, expected, actual)
- }
-}
-
-func TestResolveAddr(t *testing.T) {
- // NOTE: If tests fail due to comparing to string "127.0.0.1",
- // it's possible that system env resolves with IPv6, or ::1.
- // If that happens, maybe we should use actualAddr.IP.IsLoopback()
- // for the assertion, rather than a direct string comparison.
-
- // NOTE: Tests with {Host: "", Port: ""} and {Host: "localhost", Port: ""}
- // will not behave the same cross-platform, so they have been omitted.
-
- for i, test := range []struct {
- config server.Config
- shouldWarnErr bool
- shouldFatalErr bool
- expectedIP string
- expectedPort int
- }{
- {server.Config{Host: "127.0.0.1", Port: "1234"}, false, false, "", 1234},
- {server.Config{Host: "localhost", Port: "80"}, false, false, "", 80},
- {server.Config{BindHost: "localhost", Port: "1234"}, false, false, "127.0.0.1", 1234},
- {server.Config{BindHost: "127.0.0.1", Port: "1234"}, false, false, "127.0.0.1", 1234},
- {server.Config{BindHost: "should-not-resolve", Port: "1234"}, true, false, "", 1234},
- {server.Config{BindHost: "localhost", Port: "http"}, false, false, "127.0.0.1", 80},
- {server.Config{BindHost: "localhost", Port: "https"}, false, false, "127.0.0.1", 443},
- {server.Config{BindHost: "", Port: "1234"}, false, false, "", 1234},
- {server.Config{BindHost: "localhost", Port: "abcd"}, false, true, "", 0},
- {server.Config{BindHost: "127.0.0.1", Host: "should-not-be-used", Port: "1234"}, false, false, "127.0.0.1", 1234},
- {server.Config{BindHost: "localhost", Host: "should-not-be-used", Port: "1234"}, false, false, "127.0.0.1", 1234},
- {server.Config{BindHost: "should-not-resolve", Host: "localhost", Port: "1234"}, true, false, "", 1234},
- } {
- actualAddr, warnErr, fatalErr := resolveAddr(test.config)
-
- if test.shouldFatalErr && fatalErr == nil {
- t.Errorf("Test %d: Expected error, but there wasn't any", i)
- }
- if !test.shouldFatalErr && fatalErr != nil {
- t.Errorf("Test %d: Expected no error, but there was one: %v", i, fatalErr)
- }
- if fatalErr != nil {
- continue
- }
-
- if test.shouldWarnErr && warnErr == nil {
- t.Errorf("Test %d: Expected warning, but there wasn't any", i)
- }
- if !test.shouldWarnErr && warnErr != nil {
- t.Errorf("Test %d: Expected no warning, but there was one: %v", i, warnErr)
- }
-
- if actual, expected := actualAddr.IP.String(), test.expectedIP; actual != expected {
- t.Errorf("Test %d: IP was %s but expected %s", i, actual, expected)
- }
- if actual, expected := actualAddr.Port, test.expectedPort; actual != expected {
- t.Errorf("Test %d: Port was %d but expected %d", i, actual, expected)
- }
- }
-}
-
-func TestMakeOnces(t *testing.T) {
- directives := []directive{
- {"dummy", nil},
- {"dummy2", nil},
- }
- directiveOrder = directives
- onces := makeOnces()
- if len(onces) != len(directives) {
- t.Errorf("onces had len %d , expected %d", len(onces), len(directives))
- }
- expected := map[string]*sync.Once{
- "dummy": new(sync.Once),
- "dummy2": new(sync.Once),
- }
- if !reflect.DeepEqual(onces, expected) {
- t.Errorf("onces was %v, expected %v", onces, expected)
- }
-}
-
-func TestMakeStorages(t *testing.T) {
- directives := []directive{
- {"dummy", nil},
- {"dummy2", nil},
- }
- directiveOrder = directives
- storages := makeStorages()
- if len(storages) != len(directives) {
- t.Errorf("storages had len %d , expected %d", len(storages), len(directives))
- }
- expected := map[string]interface{}{
- "dummy": nil,
- "dummy2": nil,
- }
- if !reflect.DeepEqual(storages, expected) {
- t.Errorf("storages was %v, expected %v", storages, expected)
- }
-}
-
-func TestValidDirective(t *testing.T) {
- directives := []directive{
- {"dummy", nil},
- {"dummy2", nil},
- }
- directiveOrder = directives
- for i, test := range []struct {
- directive string
- valid bool
- }{
- {"dummy", true},
- {"dummy2", true},
- {"dummy3", false},
- } {
- if actual, expected := validDirective(test.directive), test.valid; actual != expected {
- t.Errorf("Test %d: valid was %t, expected %t", i, actual, expected)
- }
- }
-}
diff --git a/core/core.go b/core/core.go
deleted file mode 100644
index 94dd06e52..000000000
--- a/core/core.go
+++ /dev/null
@@ -1,426 +0,0 @@
-// Package core implements the CoreDNS web server as a service
-// in your own Go programs.
-//
-// To use this package, follow a few simple steps:
-//
-// 1. Set the AppName and AppVersion variables.
-// 2. Call LoadCorefile() to get the Corefile (it
-// might have been piped in as part of a restart).
-// You should pass in your own Corefile loader.
-// 3. Call core.Start() to start CoreDNS, core.Stop()
-// to stop it, or core.Restart() to restart it.
-//
-// You should use core.Wait() to wait for all CoreDNS servers
-// to quit before your process exits.
-package core
-
-import (
- "bytes"
- "encoding/gob"
- "errors"
- "fmt"
- "io/ioutil"
- "log"
- "net"
- "os"
- "path"
- "strings"
- "sync"
- "sync/atomic"
- "testing"
- "time"
-
- "github.com/miekg/coredns/core/https"
- "github.com/miekg/coredns/server"
-)
-
-// Configurable application parameters
-var (
- // AppName is the name of the application.
- AppName string
-
- // AppVersion is the version of the application.
- AppVersion string
-
- // Quiet when set to true, will not show any informative output on initialization.
- Quiet bool
-
- // PidFile is the path to the pidfile to create.
- PidFile string
-
- // GracefulTimeout is the maximum duration of a graceful shutdown.
- GracefulTimeout time.Duration
-)
-
-var (
- // corefile is the input configuration text used for this process
- corefile Input
-
- // corefileMu protects corefile during changes
- corefileMu sync.Mutex
-
- // errIncompleteRestart occurs if this process is a fork
- // of the parent but no Corefile was piped in
- errIncompleteRestart = errors.New("incomplete restart")
-
- // servers is a list of all the currently-listening servers
- servers []*server.Server
-
- // serversMu protects the servers slice during changes
- serversMu sync.Mutex
-
- // wg is used to wait for all servers to shut down
- wg sync.WaitGroup
-
- // loadedGob is used if this is a child process as part of
- // a graceful restart; it is used to map listeners to their
- // index in the list of inherited file descriptors. This
- // variable is not safe for concurrent access.
- loadedGob corefileGob
-
- // startedBefore should be set to true if CoreDNS has been started
- // at least once (does not indicate whether currently running).
- startedBefore bool
-)
-
-const (
- // DefaultHost is the default host.
- DefaultHost = ""
- // DefaultPort is the default port.
- DefaultPort = "53"
- // DefaultRoot is the default root folder.
- DefaultRoot = "."
-)
-
-// Start starts CoreDNS with the given Corefile. If crfile
-// is nil, the LoadCorefile function will be called to get
-// one.
-//
-// This function blocks until all the servers are listening.
-//
-// Note (POSIX): If Start is called in the child process of a
-// restart more than once within the duration of the graceful
-// cutoff (i.e. the child process called Start a first time,
-// then called Stop, then Start again within the first 5 seconds
-// or however long GracefulTimeout is) and the Corefiles have
-// at least one listener address in common, the second Start
-// may fail with "address already in use" as there's no
-// guarantee that the parent process has relinquished the
-// address before the grace period ends.
-func Start(crfile Input) (err error) {
- // If we return with no errors, we must do two things: tell the
- // parent that we succeeded and write to the pidfile.
- defer func() {
- if err == nil {
- signalSuccessToParent() // TODO: Is doing this more than once per process a bad idea? Start could get called more than once in other apps.
- if PidFile != "" {
- err := writePidFile()
- if err != nil {
- log.Printf("[ERROR] Could not write pidfile: %v", err)
- }
- }
- }
- }()
-
- // Input must never be nil; try to load something
- if crfile == nil {
- crfile, err = LoadCorefile(nil)
- if err != nil {
- return err
- }
- }
-
- corefileMu.Lock()
- corefile = crfile
- corefileMu.Unlock()
-
- // load the server configs (activates Let's Encrypt)
- configs, err := loadConfigs(path.Base(crfile.Path()), bytes.NewReader(crfile.Body()))
- if err != nil {
- return err
- }
-
- // group zones by address
- groupings, err := arrangeBindings(configs)
- if err != nil {
- return err
- }
-
- // Start each server with its one or more configurations
- err = startServers(groupings)
- if err != nil {
- return err
- }
- startedBefore = true
-
- // Show initialization output
- if !Quiet && !IsRestart() {
- var checkedFdLimit bool
- for _, group := range groupings {
- for _, conf := range group.Configs {
- // Print address of site
- fmt.Println(conf.Address())
-
- // Note if non-localhost site resolves to loopback interface
- if group.BindAddr.IP.IsLoopback() && !isLocalhost(conf.Host) {
- fmt.Printf("Notice: %s is only accessible on this machine (%s)\n",
- conf.Host, group.BindAddr.IP.String())
- }
- if !checkedFdLimit && !group.BindAddr.IP.IsLoopback() && !isLocalhost(conf.Host) {
- checkFdlimit()
- checkedFdLimit = true
- }
- }
- }
- }
-
- return nil
-}
-
-// startServers starts all the servers in groupings,
-// taking into account whether or not this process is
-// a child from a graceful restart or not. It blocks
-// until the servers are listening.
-func startServers(groupings bindingGroup) error {
- var startupWg sync.WaitGroup
- errChan := make(chan error, len(groupings)) // must be buffered to allow Serve functions below to return if stopped later
-
- for _, group := range groupings {
- s, err := server.New(group.BindAddr.String(), group.Configs, GracefulTimeout)
- if err != nil {
- return err
- }
- // TODO(miek): does not work, because this callback uses http instead of dns
- // s.ReqCallback = https.RequestCallback // ensures we can solve ACME challenges while running
- if s.OnDemandTLS {
- s.TLSConfig.GetCertificate = https.GetOrObtainCertificate // TLS on demand -- awesome!
- } else {
- s.TLSConfig.GetCertificate = https.GetCertificate
- }
-
- var (
- ln net.Listener
- pc net.PacketConn
- )
-
- if IsRestart() {
- // Look up this server's listener in the map of inherited file descriptors; if we don't have one, we must make a new one (later).
- if fdIndex, ok := loadedGob.ListenerFds["tcp"+s.Addr]; ok {
- file := os.NewFile(fdIndex, "")
-
- fln, err := net.FileListener(file)
- if err != nil {
- return err
- }
-
- ln, ok = fln.(*net.TCPListener)
- if !ok {
- return errors.New("listener for " + s.Addr + " was not a *net.TCPListener")
- }
-
- file.Close()
- delete(loadedGob.ListenerFds, "tcp"+s.Addr)
- }
- if fdIndex, ok := loadedGob.ListenerFds["udp"+s.Addr]; ok {
- file := os.NewFile(fdIndex, "")
-
- fpc, err := net.FilePacketConn(file)
- if err != nil {
- return err
- }
-
- pc, ok = fpc.(*net.UDPConn)
- if !ok {
- return errors.New("packetConn for " + s.Addr + " was not a *net.PacketConn")
- }
-
- file.Close()
- delete(loadedGob.ListenerFds, "udp"+s.Addr)
- }
- }
-
- wg.Add(1)
- go func(s *server.Server, ln net.Listener, pc net.PacketConn) {
- defer wg.Done()
-
- // run startup functions that should only execute when the original parent process is starting.
- if !IsRestart() && !startedBefore {
- err := s.RunFirstStartupFuncs()
- if err != nil {
- errChan <- err
- return
- }
- }
-
- // start the server
- if ln != nil && pc != nil {
- errChan <- s.Serve(ln, pc)
- } else {
- errChan <- s.ListenAndServe()
- }
- }(s, ln, pc)
-
- startupWg.Add(1)
- go func(s *server.Server) {
- defer startupWg.Done()
- s.WaitUntilStarted()
- }(s)
-
- serversMu.Lock()
- servers = append(servers, s)
- serversMu.Unlock()
- }
-
- // Close the remaining (unused) file descriptors to free up resources
- if IsRestart() {
- for key, fdIndex := range loadedGob.ListenerFds {
- os.NewFile(fdIndex, "").Close()
- delete(loadedGob.ListenerFds, key)
- }
- }
-
- // Wait for all servers to finish starting
- startupWg.Wait()
-
- // Return the first error, if any
- select {
- case err := <-errChan:
- // "use of closed network connection" is normal if it was a graceful shutdown
- if err != nil && !strings.Contains(err.Error(), "use of closed network connection") {
- return err
- }
- default:
- }
-
- return nil
-}
-
-// Stop stops all servers. It blocks until they are all stopped.
-// It does NOT execute shutdown callbacks that may have been
-// configured by middleware (they must be executed separately).
-func Stop() error {
- https.Deactivate()
-
- serversMu.Lock()
- for _, s := range servers {
- if err := s.Stop(); err != nil {
- log.Printf("[ERROR] Stopping %s: %v", s.Addr, err)
- }
- }
- servers = []*server.Server{} // don't reuse servers
- serversMu.Unlock()
-
- return nil
-}
-
-// Wait blocks until all servers are stopped.
-func Wait() {
- wg.Wait()
-}
-
-// LoadCorefile loads a Corefile, prioritizing a Corefile
-// piped from stdin as part of a restart (only happens on first call
-// to LoadCorefile). If it is not a restart, this function tries
-// calling the user's loader function, and if that returns nil, then
-// this function resorts to the default configuration. Thus, if there
-// are no other errors, this function always returns at least the
-// default Corefile.
-func LoadCorefile(loader func() (Input, error)) (crfile Input, err error) {
- // If we are a fork, finishing the restart is highest priority;
- // piped input is required in this case.
- if IsRestart() {
- err := gob.NewDecoder(os.Stdin).Decode(&loadedGob)
- if err != nil {
- return nil, err
- }
- crfile = loadedGob.Corefile
- atomic.StoreInt32(https.OnDemandIssuedCount, loadedGob.OnDemandTLSCertsIssued)
- }
-
- // Try user's loader
- if crfile == nil && loader != nil {
- crfile, err = loader()
- }
-
- // Otherwise revert to default
- if crfile == nil {
- crfile = DefaultInput()
- }
-
- return
-}
-
-// CorefileFromPipe loads the Corefile input from f if f is
-// not interactive input. f is assumed to be a pipe or stream,
-// such as os.Stdin. If f is not a pipe, no error is returned
-// but the Input value will be nil. An error is only returned
-// if there was an error reading the pipe, even if the length
-// of what was read is 0.
-func CorefileFromPipe(f *os.File) (Input, error) {
- fi, err := f.Stat()
- if err == nil && fi.Mode()&os.ModeCharDevice == 0 {
- // Note that a non-nil error is not a problem. Windows
- // will not create a stdin if there is no pipe, which
- // produces an error when calling Stat(). But Unix will
- // make one either way, which is why we also check that
- // bitmask.
- // BUG: Reading from stdin after this fails (e.g. for the let's encrypt email address) (OS X)
- confBody, err := ioutil.ReadAll(f)
- if err != nil {
- return nil, err
- }
- return CorefileInput{
- Contents: confBody,
- Filepath: f.Name(),
- }, nil
- }
-
- // not having input from the pipe is not itself an error,
- // just means no input to return.
- return nil, nil
-}
-
-// Corefile returns the current Corefile
-func Corefile() Input {
- corefileMu.Lock()
- defer corefileMu.Unlock()
- return corefile
-}
-
-// Input represents a Corefile; its contents and file path
-// (which should include the file name at the end of the path).
-// If path does not apply (e.g. piped input) you may use
-// any understandable value. The path is mainly used for logging,
-// error messages, and debugging.
-type Input interface {
- // Gets the Corefile contents
- Body() []byte
-
- // Gets the path to the origin file
- Path() string
-
- // IsFile returns true if the original input was a file on the file system
- // that could be loaded again later if requested.
- IsFile() bool
-}
-
-// TestServer returns a test server.
-// The ports can be retreived with server.LocalAddr(). The testserver itself can be stopped
-// with Stop(). It just takes a normal Corefile as input.
-func TestServer(t *testing.T, corefile string) (*server.Server, error) {
-
- crfile := CorefileInput{Contents: []byte(corefile)}
- configs, err := loadConfigs(path.Base(crfile.Path()), bytes.NewReader(crfile.Body()))
- if err != nil {
- return nil, err
- }
- groupings, err := arrangeBindings(configs)
- if err != nil {
- return nil, err
- }
- t.Logf("Starting %d servers", len(groupings))
-
- group := groupings[0]
- s, err := server.New(group.BindAddr.String(), group.Configs, time.Second)
- return s, err
-}
diff --git a/core/coredns.go b/core/coredns.go
new file mode 100644
index 000000000..e770a1a30
--- /dev/null
+++ b/core/coredns.go
@@ -0,0 +1,26 @@
+package core
+
+import (
+ // plug in the server
+ _ "github.com/miekg/coredns/core/dnsserver"
+
+ // plug in the standard directives
+ _ "github.com/miekg/coredns/middleware/bind"
+ _ "github.com/miekg/coredns/middleware/health"
+ _ "github.com/miekg/coredns/middleware/pprof"
+
+ _ "github.com/miekg/coredns/middleware/errors"
+ _ "github.com/miekg/coredns/middleware/loadbalance"
+ _ "github.com/miekg/coredns/middleware/log"
+ _ "github.com/miekg/coredns/middleware/metrics"
+ _ "github.com/miekg/coredns/middleware/rewrite"
+
+ _ "github.com/miekg/coredns/middleware/cache"
+ _ "github.com/miekg/coredns/middleware/chaos"
+ _ "github.com/miekg/coredns/middleware/dnssec"
+ _ "github.com/miekg/coredns/middleware/etcd"
+ _ "github.com/miekg/coredns/middleware/file"
+ _ "github.com/miekg/coredns/middleware/kubernetes"
+ _ "github.com/miekg/coredns/middleware/proxy"
+ _ "github.com/miekg/coredns/middleware/secondary"
+)
diff --git a/core/coremain/run.go b/core/coremain/run.go
deleted file mode 100644
index 407beb1c6..000000000
--- a/core/coremain/run.go
+++ /dev/null
@@ -1,232 +0,0 @@
-package coremain
-
-import (
- "errors"
- "flag"
- "fmt"
- "io/ioutil"
- "log"
- "os"
- "runtime"
- "strconv"
- "strings"
- "time"
-
- "github.com/miekg/coredns/core"
- "github.com/miekg/coredns/core/https"
- "github.com/xenolf/lego/acme"
- "gopkg.in/natefinch/lumberjack.v2"
-)
-
-func init() {
- core.TrapSignals()
- setVersion()
- flag.BoolVar(&https.Agreed, "agree", false, "Agree to Let's Encrypt Subscriber Agreement")
- flag.StringVar(&https.CAUrl, "ca", "https://acme-v01.api.letsencrypt.org/directory", "Certificate authority ACME server")
- flag.StringVar(&conf, "conf", "", "Configuration file to use (default="+core.DefaultConfigFile+")")
- flag.StringVar(&cpu, "cpu", "100%", "CPU cap")
- flag.StringVar(&https.DefaultEmail, "email", "", "Default Let's Encrypt account email address")
- flag.DurationVar(&core.GracefulTimeout, "grace", 5*time.Second, "Maximum duration of graceful shutdown")
- flag.StringVar(&core.Host, "host", core.DefaultHost, "Default host")
- flag.StringVar(&logfile, "log", "", "Process log file")
- flag.StringVar(&core.PidFile, "pidfile", "", "Path to write pid file")
- flag.StringVar(&core.Port, "port", core.DefaultPort, "Default port")
- flag.BoolVar(&core.Quiet, "quiet", false, "Quiet mode (no initialization output)")
- flag.StringVar(&revoke, "revoke", "", "Hostname for which to revoke the certificate")
- flag.StringVar(&core.Root, "root", core.DefaultRoot, "Root path to default zone files")
- flag.BoolVar(&version, "version", false, "Show version")
-}
-
-func Run() {
- flag.Parse() // called here in Run() to allow other packages to set flags in their inits
-
- core.AppName = appName
- core.AppVersion = appVersion
- acme.UserAgent = appName + "/" + appVersion
-
- // set up process log before anything bad happens
- switch logfile {
- case "stdout":
- log.SetOutput(os.Stdout)
- case "stderr":
- log.SetOutput(os.Stderr)
- case "":
- log.SetOutput(ioutil.Discard)
- default:
- log.SetOutput(&lumberjack.Logger{
- Filename: logfile,
- MaxSize: 100,
- MaxAge: 14,
- MaxBackups: 10,
- })
- }
-
- if revoke != "" {
- err := https.Revoke(revoke)
- if err != nil {
- log.Fatal(err)
- }
- fmt.Printf("Revoked certificate for %s\n", revoke)
- os.Exit(0)
- }
- if version {
- fmt.Printf("%s %s\n", appName, appVersion)
- if devBuild && gitShortStat != "" {
- fmt.Printf("%s\n%s\n", gitShortStat, gitFilesModified)
- }
- os.Exit(0)
- }
-
- // Set CPU cap
- err := setCPU(cpu)
- if err != nil {
- mustLogFatal(err)
- }
-
- // Get Corefile input
- corefile, err := core.LoadCorefile(loadCorefile)
- if err != nil {
- mustLogFatal(err)
- }
-
- // Start your engines
- err = core.Start(corefile)
- if err != nil {
- mustLogFatal(err)
- }
-
- // Twiddle your thumbs
- core.Wait()
-}
-
-// mustLogFatal just wraps log.Fatal() in a way that ensures the
-// output is always printed to stderr so the user can see it
-// if the user is still there, even if the process log was not
-// enabled. If this process is a restart, however, and the user
-// might not be there anymore, this just logs to the process log
-// and exits.
-func mustLogFatal(args ...interface{}) {
- if !core.IsRestart() {
- log.SetOutput(os.Stderr)
- }
- log.Fatal(args...)
-}
-
-func loadCorefile() (core.Input, error) {
- // Try -conf flag
- if conf != "" {
- if conf == "stdin" {
- return core.CorefileFromPipe(os.Stdin)
- }
-
- contents, err := ioutil.ReadFile(conf)
- if err != nil {
- return nil, err
- }
-
- return core.CorefileInput{
- Contents: contents,
- Filepath: conf,
- RealFile: true,
- }, nil
- }
-
- // command line args
- if flag.NArg() > 0 {
- confBody := core.Host + ":" + core.Port + "\n" + strings.Join(flag.Args(), "\n")
- return core.CorefileInput{
- Contents: []byte(confBody),
- Filepath: "args",
- }, nil
- }
-
- // Corefile in cwd
- contents, err := ioutil.ReadFile(core.DefaultConfigFile)
- if err != nil {
- if os.IsNotExist(err) {
- return core.DefaultInput(), nil
- }
- return nil, err
- }
- return core.CorefileInput{
- Contents: contents,
- Filepath: core.DefaultConfigFile,
- RealFile: true,
- }, nil
-}
-
-// setCPU parses string cpu and sets GOMAXPROCS
-// according to its value. It accepts either
-// a number (e.g. 3) or a percent (e.g. 50%).
-func setCPU(cpu string) error {
- var numCPU int
-
- availCPU := runtime.NumCPU()
-
- if strings.HasSuffix(cpu, "%") {
- // Percent
- var percent float32
- pctStr := cpu[:len(cpu)-1]
- pctInt, err := strconv.Atoi(pctStr)
- if err != nil || pctInt < 1 || pctInt > 100 {
- return errors.New("invalid CPU value: percentage must be between 1-100")
- }
- percent = float32(pctInt) / 100
- numCPU = int(float32(availCPU) * percent)
- } else {
- // Number
- num, err := strconv.Atoi(cpu)
- if err != nil || num < 1 {
- return errors.New("invalid CPU value: provide a number or percent greater than 0")
- }
- numCPU = num
- }
-
- if numCPU > availCPU {
- numCPU = availCPU
- }
-
- runtime.GOMAXPROCS(numCPU)
- return nil
-}
-
-// setVersion figures out the version information based on
-// variables set by -ldflags.
-func setVersion() {
- // A development build is one that's not at a tag or has uncommitted changes
- devBuild = gitTag == "" || gitShortStat != ""
-
- // Only set the appVersion if -ldflags was used
- if gitNearestTag != "" || gitTag != "" {
- if devBuild && gitNearestTag != "" {
- appVersion = fmt.Sprintf("%s (+%s %s)",
- strings.TrimPrefix(gitNearestTag, "v"), gitCommit, buildDate)
- } else if gitTag != "" {
- appVersion = strings.TrimPrefix(gitTag, "v")
- }
- }
-}
-
-const appName = "CoreDNS"
-
-// Flags that control program flow or startup
-var (
- conf string
- cpu string
- logfile string
- revoke string
- version bool
-)
-
-// Build information obtained with the help of -ldflags
-var (
- appVersion = "(untracked dev build)" // inferred at startup
- devBuild = true // inferred at startup
-
- buildDate string // date -u
- gitTag string // git describe --exact-match HEAD 2> /dev/null
- gitNearestTag string // git describe --abbrev=0 --tags HEAD
- gitCommit string // git rev-parse HEAD
- gitShortStat string // git diff-index --shortstat
- gitFilesModified string // git diff-index --name-only HEAD
-)
diff --git a/core/coremain/run_test.go b/core/coremain/run_test.go
deleted file mode 100644
index 149dab0c1..000000000
--- a/core/coremain/run_test.go
+++ /dev/null
@@ -1,75 +0,0 @@
-package coremain
-
-import (
- "runtime"
- "testing"
-)
-
-func TestSetCPU(t *testing.T) {
- currentCPU := runtime.GOMAXPROCS(-1)
- maxCPU := runtime.NumCPU()
- halfCPU := int(0.5 * float32(maxCPU))
- if halfCPU < 1 {
- halfCPU = 1
- }
- for i, test := range []struct {
- input string
- output int
- shouldErr bool
- }{
- {"1", 1, false},
- {"-1", currentCPU, true},
- {"0", currentCPU, true},
- {"100%", maxCPU, false},
- {"50%", halfCPU, false},
- {"110%", currentCPU, true},
- {"-10%", currentCPU, true},
- {"invalid input", currentCPU, true},
- {"invalid input%", currentCPU, true},
- {"9999", maxCPU, false}, // over available CPU
- } {
- err := setCPU(test.input)
- if test.shouldErr && err == nil {
- t.Errorf("Test %d: Expected error, but there wasn't any", i)
- }
- if !test.shouldErr && err != nil {
- t.Errorf("Test %d: Expected no error, but there was one: %v", i, err)
- }
- if actual, expected := runtime.GOMAXPROCS(-1), test.output; actual != expected {
- t.Errorf("Test %d: GOMAXPROCS was %d but expected %d", i, actual, expected)
- }
- // teardown
- runtime.GOMAXPROCS(currentCPU)
- }
-}
-
-func TestSetVersion(t *testing.T) {
- setVersion()
- if !devBuild {
- t.Error("Expected default to assume development build, but it didn't")
- }
- if got, want := appVersion, "(untracked dev build)"; got != want {
- t.Errorf("Expected appVersion='%s', got: '%s'", want, got)
- }
-
- gitTag = "v1.1"
- setVersion()
- if devBuild {
- t.Error("Expected a stable build if gitTag is set with no changes")
- }
- if got, want := appVersion, "1.1"; got != want {
- t.Errorf("Expected appVersion='%s', got: '%s'", want, got)
- }
-
- gitTag = ""
- gitNearestTag = "v1.0"
- gitCommit = "deadbeef"
- buildDate = "Fri Feb 26 06:53:17 UTC 2016"
- setVersion()
- if !devBuild {
- t.Error("Expected inferring a dev build when gitTag is empty")
- }
- if got, want := appVersion, "1.0 (+deadbeef Fri Feb 26 06:53:17 UTC 2016)"; got != want {
- t.Errorf("Expected appVersion='%s', got: '%s'", want, got)
- }
-}
diff --git a/core/directives.go b/core/directives.go
deleted file mode 100644
index 63e245578..000000000
--- a/core/directives.go
+++ /dev/null
@@ -1,98 +0,0 @@
-package core
-
-import (
- "github.com/miekg/coredns/core/https"
- "github.com/miekg/coredns/core/parse"
- "github.com/miekg/coredns/core/setup"
- "github.com/miekg/coredns/middleware"
-)
-
-func init() {
- // The parse package must know which directives
- // are valid, but it must not import the setup
- // or config package. To solve this problem, we
- // fill up this map in our init function here.
- // The parse package does not need to know the
- // ordering of the directives.
- for _, dir := range directiveOrder {
- parse.ValidDirectives[dir.name] = struct{}{}
- }
-}
-
-// Directives are registered in the order they should be
-// executed. Middleware (directives that inject a handler)
-// are executed in the order A-B-C-*-C-B-A, assuming
-// they all call the Next handler in the chain.
-//
-// Ordering is VERY important. Every middleware will
-// feel the effects of all other middleware below
-// (after) them during a request, but they must not
-// care what middleware above them are doing.
-//
-// For example, log needs to know the status code and
-// exactly how many bytes were written to the client,
-// which every other middleware can affect, so it gets
-// registered first. The errors middleware does not
-// care if gzip or log modifies its response, so it
-// gets registered below them. Gzip, on the other hand,
-// DOES care what errors does to the response since it
-// must compress every output to the client, even error
-// pages, so it must be registered before the errors
-// middleware and any others that would write to the
-// response.
-var directiveOrder = []directive{
- // Essential directives that initialize vital configuration settings
- {"root", setup.Root},
- {"bind", setup.BindHost},
- {"tls", https.Setup},
- {"health", setup.Health},
- {"pprof", setup.PProf},
-
- // Other directives that don't create HTTP handlers
- {"startup", setup.Startup},
- {"shutdown", setup.Shutdown},
-
- // Directives that inject handlers (middleware)
- {"prometheus", setup.Prometheus},
- {"errors", setup.Errors},
- {"log", setup.Log},
-
- {"chaos", setup.Chaos},
- {"rewrite", setup.Rewrite},
- {"loadbalance", setup.Loadbalance},
- {"cache", setup.Cache},
- {"dnssec", setup.Dnssec},
- {"file", setup.File},
- {"secondary", setup.Secondary},
- {"etcd", setup.Etcd},
- {"kubernetes", setup.Kubernetes},
- {"proxy", setup.Proxy},
-}
-
-// RegisterDirective adds the given directive to CoreDNS's list of directives.
-// Pass the name of a directive you want it to be placed after,
-// otherwise it will be placed at the bottom of the stack.
-func RegisterDirective(name string, setup SetupFunc, after string) {
- dir := directive{name: name, setup: setup}
- idx := len(directiveOrder)
- for i := range directiveOrder {
- if directiveOrder[i].name == after {
- idx = i + 1
- break
- }
- }
- newDirectives := append(directiveOrder[:idx], append([]directive{dir}, directiveOrder[idx:]...)...)
- directiveOrder = newDirectives
- parse.ValidDirectives[name] = struct{}{}
-}
-
-// directive ties together a directive name with its setup function.
-type directive struct {
- name string
- setup SetupFunc
-}
-
-// SetupFunc takes a controller and may optionally return a middleware.
-// If the resulting middleware is not nil, it will be chained into
-// the DNS handlers in the order specified in this package.
-type SetupFunc func(c *setup.Controller) (middleware.Middleware, error)
diff --git a/core/directives_test.go b/core/directives_test.go
deleted file mode 100644
index 1bee144f5..000000000
--- a/core/directives_test.go
+++ /dev/null
@@ -1,31 +0,0 @@
-package core
-
-import (
- "reflect"
- "testing"
-)
-
-func TestRegister(t *testing.T) {
- directives := []directive{
- {"dummy", nil},
- {"dummy2", nil},
- }
- directiveOrder = directives
- RegisterDirective("foo", nil, "dummy")
- if len(directiveOrder) != 3 {
- t.Fatal("Should have 3 directives now")
- }
- getNames := func() (s []string) {
- for _, d := range directiveOrder {
- s = append(s, d.name)
- }
- return s
- }
- if !reflect.DeepEqual(getNames(), []string{"dummy", "foo", "dummy2"}) {
- t.Fatalf("directive order doesn't match: %s", getNames())
- }
- RegisterDirective("bar", nil, "ASDASD")
- if !reflect.DeepEqual(getNames(), []string{"dummy", "foo", "dummy2", "bar"}) {
- t.Fatalf("directive order doesn't match: %s", getNames())
- }
-}
diff --git a/core/dns/storage.go b/core/dns/storage.go
deleted file mode 100644
index 0c8e68437..000000000
--- a/core/dns/storage.go
+++ /dev/null
@@ -1,28 +0,0 @@
-package dns
-
-import (
- "path/filepath"
-
- "github.com/miekg/coredns/core/assets"
-)
-
-var storage = Storage(assets.Path())
-
-// Storage is a root directory and facilitates
-// forming file paths derived from it.
-type Storage string
-
-// Zones gets the directory that stores zones data.
-func (s Storage) Zones() string {
- return filepath.Join(string(s), "zones")
-}
-
-// Zone returns the path to the folder containing assets for domain.
-func (s Storage) Zone(domain string) string {
- return filepath.Join(s.Zones(), domain)
-}
-
-// SecondaryZoneFile returns the path to domain's secondary zone file (when fetched).
-func (s Storage) SecondaryZoneFile(domain string) string {
- return filepath.Join(s.Zone(domain), "db."+domain)
-}
diff --git a/core/dns/storage_test.go b/core/dns/storage_test.go
deleted file mode 100644
index 859435a15..000000000
--- a/core/dns/storage_test.go
+++ /dev/null
@@ -1,20 +0,0 @@
-package dns
-
-import (
- "path/filepath"
- "testing"
-)
-
-func TestStorage(t *testing.T) {
- storage = Storage("./le_test")
-
- if expected, actual := filepath.Join("le_test", "zones"), storage.Zones(); actual != expected {
- t.Errorf("Expected Zones() to return '%s' but got '%s'", expected, actual)
- }
- if expected, actual := filepath.Join("le_test", "zones", "test.com"), storage.Zone("test.com"); actual != expected {
- t.Errorf("Expected Site() to return '%s' but got '%s'", expected, actual)
- }
- if expected, actual := filepath.Join("le_test", "zones", "test.com", "db.test.com"), storage.SecondaryZoneFile("test.com"); actual != expected {
- t.Errorf("Expected SecondaryZoneFile() to return '%s' but got '%s'", expected, actual)
- }
-}
diff --git a/core/dnsserver/address.go b/core/dnsserver/address.go
new file mode 100644
index 000000000..865d082cc
--- /dev/null
+++ b/core/dnsserver/address.go
@@ -0,0 +1,44 @@
+package dnsserver
+
+import (
+ "fmt"
+ "net"
+ "strings"
+
+ "github.com/miekg/dns"
+)
+
+type zoneAddr struct {
+ Zone string
+ Port string
+}
+
+// String return z.Zone + ":" + z.Port as a string.
+func (z zoneAddr) String() string { return z.Zone + ":" + z.Port }
+
+// normalizeZone parses an zone string into a structured format with separate
+// host, and port portions, as well as the original input string.
+func normalizeZone(str string) (zoneAddr, error) {
+ var err error
+
+ // separate host and port
+ host, port, err := net.SplitHostPort(str)
+ if err != nil {
+ host, port, err = net.SplitHostPort(str + ":")
+ // no error check here; return err at end of function
+ }
+
+ if len(host) > 255 {
+ return zoneAddr{}, fmt.Errorf("specified zone is too long: %d > 255", len(host))
+ }
+ _, d := dns.IsDomainName(host)
+ if !d {
+ return zoneAddr{}, fmt.Errorf("zone is not a valid domain name: %s", host)
+ }
+
+ if port == "" {
+ port = "53"
+ }
+
+ return zoneAddr{Zone: strings.ToLower(dns.Fqdn(host)), Port: port}, err
+}
diff --git a/core/dnsserver/config.go b/core/dnsserver/config.go
new file mode 100644
index 000000000..7af483f21
--- /dev/null
+++ b/core/dnsserver/config.go
@@ -0,0 +1,38 @@
+package dnsserver
+
+import "github.com/mholt/caddy"
+
+// Config configuration for a single server.
+type Config struct {
+ // The zone of the site.
+ Zone string
+
+ // The hostname to bind listener to, defaults to the wildcard address
+ ListenHost string
+
+ // The port to listen on.
+ Port string
+
+ // The directory from which to parse db files, and store keys.
+ Root string
+
+ // Middleware stack.
+ Middleware []Middleware
+
+ // Compiled middleware stack.
+ middlewareChain Handler
+}
+
+// GetConfig gets the Config that corresponds to c.
+// If none exist nil is returned.
+func GetConfig(c *caddy.Controller) *Config {
+ ctx := c.Context().(*dnsContext)
+ if cfg, ok := ctx.keysToConfigs[c.Key]; ok {
+ return cfg
+ }
+ // we should only get here during tests because directive
+ // actions typically skip the server blocks where we make
+ // the configs.
+ ctx.saveConfig(c.Key, &Config{Root: Root})
+ return GetConfig(c)
+}
diff --git a/core/dnsserver/directives.go b/core/dnsserver/directives.go
new file mode 100644
index 000000000..78a8a11f7
--- /dev/null
+++ b/core/dnsserver/directives.go
@@ -0,0 +1,32 @@
+package dnsserver
+
+// Add here, and in core/coredns.go to use them.
+
+// Directives are registered in the order they should be
+// executed.
+//
+// Ordering is VERY important. Every middleware will
+// feel the effects of all other middleware below
+// (after) them during a request, but they must not
+// care what middleware above them are doing.
+var Directives = []string{
+ "bind",
+ "health",
+ "pprof",
+
+ "prometheus",
+ "errors",
+ "log",
+ "chaos",
+ "cache",
+
+ "rewrite",
+ "loadbalance",
+
+ "dnssec",
+ "file",
+ "secondary",
+ "etcd",
+ "kubernetes",
+ "proxy",
+}
diff --git a/core/dnsserver/middleware.go b/core/dnsserver/middleware.go
new file mode 100644
index 000000000..5bce304b1
--- /dev/null
+++ b/core/dnsserver/middleware.go
@@ -0,0 +1,52 @@
+package dnsserver
+
+import (
+ "github.com/miekg/dns"
+ "golang.org/x/net/context"
+)
+
+type (
+ // Middleware is the middle layer which represents the traditional
+ // idea of middleware: it chains one Handler to the next by being
+ // passed the next Handler in the chain.
+ Middleware func(Handler) Handler
+
+ // Handler is like dns.Handler except ServeDNS may return an rcode
+ // and/or error.
+ //
+ // If ServeDNS writes to the response body, it should return a status
+ // code. If the status code is not one of the following:
+ // * SERVFAIL (dns.RcodeServerFailure)
+ // * REFUSED (dns.RecodeRefused)
+ // * FORMERR (dns.RcodeFormatError)
+ // * NOTIMP (dns.RcodeNotImplemented)
+ //
+ // CoreDNS assumes *no* reply has yet been written. All other response
+ // codes signal other handlers above it that the response message is
+ // already written, and that they should not write to it also.
+ //
+ // If ServeDNS encounters an error, it should return the error value
+ // so it can be logged by designated error-handling middleware.
+ //
+ // If writing a response after calling another ServeDNS method, the
+ // returned rcode SHOULD be used when writing the response.
+ //
+ // If handling errors after calling another ServeDNS method, the
+ // returned error value SHOULD be logged or handled accordingly.
+ //
+ // Otherwise, return values should be propagated down the middleware
+ // chain by returning them unchanged.
+ Handler interface {
+ ServeDNS(context.Context, dns.ResponseWriter, *dns.Msg) (int, error)
+ }
+
+ // HandlerFunc is a convenience type like dns.HandlerFunc, except
+ // ServeDNS returns an rcode and an error. See Handler
+ // documentation for more information.
+ HandlerFunc func(context.Context, dns.ResponseWriter, *dns.Msg) (int, error)
+)
+
+// ServeDNS implements the Handler interface.
+func (f HandlerFunc) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
+ return f(ctx, w, r)
+}
diff --git a/core/dnsserver/register.go b/core/dnsserver/register.go
new file mode 100644
index 000000000..3c12c019c
--- /dev/null
+++ b/core/dnsserver/register.go
@@ -0,0 +1,156 @@
+package dnsserver
+
+import (
+ "fmt"
+ "net"
+ "time"
+
+ "github.com/mholt/caddy"
+ "github.com/mholt/caddy/caddyfile"
+)
+
+const serverType = "dns"
+
+func init() {
+ caddy.RegisterServerType(serverType, caddy.ServerType{
+ Directives: Directives,
+ DefaultInput: func() caddy.Input {
+ if Port == DefaultPort && Zone != "" {
+ return caddy.CaddyfileInput{
+ Filepath: "Corefile",
+ Contents: nil,
+ ServerTypeName: serverType,
+ }
+ }
+ return caddy.CaddyfileInput{
+ Filepath: "Corefile",
+ Contents: nil,
+ ServerTypeName: serverType,
+ }
+ },
+ NewContext: newContext,
+ })
+}
+
+var TestNewContext = newContext
+
+func newContext() caddy.Context {
+ return &dnsContext{keysToConfigs: make(map[string]*Config)}
+}
+
+type dnsContext struct {
+ keysToConfigs map[string]*Config
+
+ // configs is the master list of all site configs.
+ configs []*Config
+}
+
+func (h *dnsContext) saveConfig(key string, cfg *Config) {
+ h.configs = append(h.configs, cfg)
+ h.keysToConfigs[key] = cfg
+}
+
+// InspectServerBlocks make sure that everything checks out before
+// executing directives and otherwise prepares the directives to
+// be parsed and executed.
+func (h *dnsContext) InspectServerBlocks(sourceFile string, serverBlocks []caddyfile.ServerBlock) ([]caddyfile.ServerBlock, error) {
+ // Normalize and check all the zone names and check for duplicates
+ dups := map[string]string{}
+ for _, s := range serverBlocks {
+ for i, k := range s.Keys {
+ za, err := normalizeZone(k)
+ if err != nil {
+ return nil, err
+ }
+ s.Keys[i] = za.String()
+ if v, ok := dups[za.Zone]; ok {
+ return nil, fmt.Errorf("cannot serve %s - zone already defined for %v", za, v)
+
+ }
+ dups[za.Zone] = za.String()
+
+ // Save the config to our master list, and key it for lookups
+ cfg := &Config{
+ Zone: za.Zone,
+ Port: za.Port,
+ // TODO(miek): more?
+ }
+ h.saveConfig(za.String(), cfg)
+ }
+ }
+ return serverBlocks, nil
+}
+
+// MakeServers uses the newly-created siteConfigs to create and return a list of server instances.
+func (h *dnsContext) MakeServers() ([]caddy.Server, error) {
+
+ // we must map (group) each config to a bind address
+ groups, err := groupConfigsByListenAddr(h.configs)
+ if err != nil {
+ return nil, err
+ }
+ // then we create a server for each group
+ var servers []caddy.Server
+ for addr, group := range groups {
+ s, err := NewServer(addr, group)
+ if err != nil {
+ return nil, err
+ }
+ servers = append(servers, s)
+ }
+
+ return servers, nil
+}
+
+// AddMiddleware adds a middleware to a site's middleware stack.
+func (sc *Config) AddMiddleware(m Middleware) {
+ sc.Middleware = append(sc.Middleware, m)
+}
+
+// groupSiteConfigsByListenAddr groups site configs by their listen
+// (bind) address, so sites that use the same listener can be served
+// on the same server instance. The return value maps the listen
+// address (what you pass into net.Listen) to the list of site configs.
+// This function does NOT vet the configs to ensure they are compatible.
+func groupConfigsByListenAddr(configs []*Config) (map[string][]*Config, error) {
+ groups := make(map[string][]*Config)
+
+ for _, conf := range configs {
+ if conf.Port == "" {
+ conf.Port = Port
+ }
+ addr, err := net.ResolveTCPAddr("tcp", net.JoinHostPort(conf.ListenHost, conf.Port))
+ if err != nil {
+ return nil, err
+ }
+ addrstr := addr.String()
+ groups[addrstr] = append(groups[addrstr], conf)
+ }
+
+ return groups, nil
+}
+
+const (
+ // DefaultZone is the default zone.
+ DefaultZone = "."
+ // DefaultPort is the default port.
+ DefaultPort = "2053"
+ // DefaultRoot is the default root folder.
+ DefaultRoot = "."
+)
+
+// These "soft defaults" are configurable by
+// command line flags, etc.
+var (
+ // Root is the site root
+ Root = DefaultRoot
+
+ // Host is the site host
+ Zone = DefaultZone
+
+ // Port is the site port
+ Port = DefaultPort
+
+ // GracefulTimeout is the maximum duration of a graceful shutdown.
+ GracefulTimeout time.Duration
+)
diff --git a/core/dnsserver/server.go b/core/dnsserver/server.go
new file mode 100644
index 000000000..27c62312b
--- /dev/null
+++ b/core/dnsserver/server.go
@@ -0,0 +1,254 @@
+package dnsserver
+
+import (
+ "log"
+ "net"
+ "runtime"
+ "sync"
+ "time"
+
+ "github.com/miekg/coredns/middleware"
+
+ "github.com/miekg/dns"
+ "golang.org/x/net/context"
+)
+
+// Server represents an instance of a server, which serves
+// DNS requests at a particular address (host and port). A
+// server is capable of serving numerous zones on
+// the same address and the listener may be stopped for
+// graceful termination (POSIX only).
+type Server struct {
+ Addr string // Address we listen on
+ mux *dns.ServeMux
+ server [2]*dns.Server // 0 is a net.Listener, 1 is a net.PacketConn (a *UDPConn) in our case.
+
+ l net.Listener
+ p net.PacketConn
+ m sync.Mutex // protects listener and packetconn
+
+ zones map[string]*Config // zones keyed by their address
+ dnsWg sync.WaitGroup // used to wait on outstanding connections
+ connTimeout time.Duration // the maximum duration of a graceful shutdown
+}
+
+func NewServer(addr string, group []*Config) (*Server, error) {
+
+ s := &Server{
+ Addr: addr,
+ zones: make(map[string]*Config),
+ connTimeout: 5 * time.Second, // TODO(miek): was configurable
+ }
+ mux := dns.NewServeMux()
+ mux.Handle(".", s) // wildcard handler, everything will go through here
+ s.mux = mux
+
+ // We have to bound our wg with one increment
+ // to prevent a "race condition" that is hard-coded
+ // into sync.WaitGroup.Wait() - basically, an add
+ // with a positive delta must be guaranteed to
+ // occur before Wait() is called on the wg.
+ // In a way, this kind of acts as a safety barrier.
+ s.dnsWg.Add(1)
+
+ for _, site := range group {
+ // set the config per zone
+ s.zones[site.Zone] = site
+ // compile custom middleware for everything
+ var stack Handler
+ for i := len(site.Middleware) - 1; i >= 0; i-- {
+ stack = site.Middleware[i](stack)
+ }
+ site.middlewareChain = stack
+ }
+
+ return s, nil
+}
+
+// LocalAddr return the addresses where the server is bound to.
+func (s *Server) LocalAddr() net.Addr {
+ s.m.Lock()
+ defer s.m.Unlock()
+ return s.l.Addr()
+}
+
+// LocalAddrPacket return the net.PacketConn address where the server is bound to.
+func (s *Server) LocalAddrPacket() net.Addr {
+ s.m.Lock()
+ defer s.m.Unlock()
+ return s.p.LocalAddr()
+}
+
+// Serve starts the server with an existing listener. It blocks until the server stops.
+func (s *Server) Serve(l net.Listener) error {
+ s.m.Lock()
+ s.server[tcp] = &dns.Server{Listener: l, Net: "tcp", Handler: s.mux}
+ s.m.Unlock()
+
+ return s.server[tcp].ActivateAndServe()
+}
+
+// ServePacket starts the server with an existing packetconn. It blocks until the server stops.
+func (s *Server) ServePacket(p net.PacketConn) error {
+ s.m.Lock()
+ s.server[udp] = &dns.Server{PacketConn: p, Net: "udp", Handler: s.mux}
+ s.m.Unlock()
+
+ return s.server[udp].ActivateAndServe()
+}
+
+func (s *Server) Listen() (net.Listener, error) {
+ l, err := net.Listen("tcp", s.Addr)
+ if err != nil {
+ return nil, err
+ }
+ s.m.Lock()
+ s.l = l
+ s.m.Unlock()
+ return l, nil
+}
+
+func (s *Server) ListenPacket() (net.PacketConn, error) {
+ p, err := net.ListenPacket("udp", s.Addr)
+ if err != nil {
+ return nil, err
+ }
+
+ s.m.Lock()
+ s.p = p
+ s.m.Unlock()
+ return p, nil
+}
+
+// Stop stops the server. It blocks until the server is
+// totally stopped. On POSIX systems, it will wait for
+// connections to close (up to a max timeout of a few
+// seconds); on Windows it will close the listener
+// immediately.
+func (s *Server) Stop() (err error) {
+
+ if runtime.GOOS != "windows" {
+ // force connections to close after timeout
+ done := make(chan struct{})
+ go func() {
+ s.dnsWg.Done() // decrement our initial increment used as a barrier
+ s.dnsWg.Wait()
+ close(done)
+ }()
+
+ // Wait for remaining connections to finish or
+ // force them all to close after timeout
+ select {
+ case <-time.After(s.connTimeout):
+ case <-done:
+ }
+ }
+
+ // Close the listener now; this stops the server without delay
+ s.m.Lock()
+ if s.l != nil {
+ err = s.l.Close()
+ }
+ if s.p != nil {
+ err = s.p.Close()
+ }
+
+ for _, s1 := range s.server {
+ err = s1.Shutdown()
+ }
+ s.m.Unlock()
+ return
+}
+
+// ServeDNS is the entry point for every request to the address that s
+// is bound to. It acts as a multiplexer for the requests zonename as
+// defined in the request so that the correct zone
+// (configuration and middleware stack) will handle the request.
+func (s *Server) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
+ // TODO(miek): expensive to use defer
+ defer func() {
+ // In case the user doesn't enable error middleware, we still
+ // need to make sure that we stay alive up here
+ if rec := recover(); rec != nil {
+ DefaultErrorFunc(w, r, dns.RcodeServerFailure)
+ }
+ }()
+
+ if m, err := middleware.Edns0Version(r); err != nil { // Wrong EDNS version, return at once.
+ w.WriteMsg(m)
+ return
+ }
+
+ q := r.Question[0].Name
+ b := make([]byte, len(q))
+ off, end := 0, false
+ ctx := context.Background()
+
+ for {
+ l := len(q[off:])
+ for i := 0; i < l; i++ {
+ b[i] = q[off+i]
+ // normalize the name for the lookup
+ if b[i] >= 'A' && b[i] <= 'Z' {
+ b[i] |= ('a' - 'A')
+ }
+ }
+
+ if h, ok := s.zones[string(b[:l])]; ok {
+ if r.Question[0].Qtype != dns.TypeDS {
+ rcode, _ := h.middlewareChain.ServeDNS(ctx, w, r)
+ if RcodeNoClientWrite(rcode) {
+ DefaultErrorFunc(w, r, rcode)
+ }
+ return
+ }
+ }
+ off, end = dns.NextLabel(q, off)
+ if end {
+ break
+ }
+ }
+ // Wildcard match, if we have found nothing try the root zone as a last resort.
+ if h, ok := s.zones["."]; ok {
+ rcode, _ := h.middlewareChain.ServeDNS(ctx, w, r)
+ if RcodeNoClientWrite(rcode) {
+ DefaultErrorFunc(w, r, rcode)
+ }
+ return
+ }
+
+ // Still here? Error out with REFUSED and some logging
+ remoteHost := w.RemoteAddr().String()
+ DefaultErrorFunc(w, r, dns.RcodeRefused)
+ log.Printf("[INFO] \"%s %s %s\" - No such zone at %s (Remote: %s)", dns.Type(r.Question[0].Qtype), dns.Class(r.Question[0].Qclass), q, s.Addr, remoteHost)
+}
+
+// DefaultErrorFunc responds to an DNS request with an error.
+func DefaultErrorFunc(w dns.ResponseWriter, r *dns.Msg, rcode int) {
+ state := middleware.State{W: w, Req: r}
+
+ answer := new(dns.Msg)
+ answer.SetRcode(r, rcode)
+ state.SizeAndDo(answer)
+
+ w.WriteMsg(answer)
+}
+
+func RcodeNoClientWrite(rcode int) bool {
+ switch rcode {
+ case dns.RcodeServerFailure:
+ fallthrough
+ case dns.RcodeRefused:
+ fallthrough
+ case dns.RcodeFormatError:
+ fallthrough
+ case dns.RcodeNotImplemented:
+ return true
+ }
+ return false
+}
+
+const (
+ tcp = 0
+ udp = 1
+)
diff --git a/core/helpers.go b/core/helpers.go
deleted file mode 100644
index 82a0150af..000000000
--- a/core/helpers.go
+++ /dev/null
@@ -1,102 +0,0 @@
-package core
-
-import (
- "bytes"
- "fmt"
- "io/ioutil"
- "log"
- "os"
- "os/exec"
- "runtime"
- "strconv"
- "strings"
- "sync"
-)
-
-// isLocalhost returns true if host looks explicitly like a localhost address.
-func isLocalhost(host string) bool {
- return host == "localhost" || host == "::1" || strings.HasPrefix(host, "127.")
-}
-
-// checkFdlimit issues a warning if the OS max file descriptors is below a recommended minimum.
-func checkFdlimit() {
- const min = 4096
-
- // Warn if ulimit is too low for production sites
- if runtime.GOOS == "linux" || runtime.GOOS == "darwin" {
- out, err := exec.Command("sh", "-c", "ulimit -n").Output() // use sh because ulimit isn't in Linux $PATH
- if err == nil {
- // Note that an error here need not be reported
- lim, err := strconv.Atoi(string(bytes.TrimSpace(out)))
- if err == nil && lim < min {
- fmt.Printf("Warning: File descriptor limit %d is too low for production sites. At least %d is recommended. Set with \"ulimit -n %d\".\n", lim, min, min)
- }
- }
- }
-}
-
-// signalSuccessToParent tells the parent our status using pipe at index 3.
-// If this process is not a restart, this function does nothing.
-// Calling this function once this process has successfully initialized
-// is vital so that the parent process can unblock and kill itself.
-// This function is idempotent; it executes at most once per process.
-func signalSuccessToParent() {
- signalParentOnce.Do(func() {
- if IsRestart() {
- ppipe := os.NewFile(3, "") // parent is reading from pipe at index 3
- _, err := ppipe.Write([]byte("success")) // we must send some bytes to the parent
- if err != nil {
- log.Printf("[ERROR] Communicating successful init to parent: %v", err)
- }
- ppipe.Close()
- }
- })
-}
-
-// signalParentOnce is used to make sure that the parent is only
-// signaled once; doing so more than once breaks whatever socket is
-// at fd 4 (the reason for this is still unclear - to reproduce,
-// call Stop() and Start() in succession at least once after a
-// restart, then try loading first host of Corefile in the browser).
-// Do not use this directly - call signalSuccessToParent instead.
-var signalParentOnce sync.Once
-
-// corefileGob maps bind address to index of the file descriptor
-// in the Files array passed to the child process. It also contains
-// the corefile contents and other state needed by the new process.
-// Used only during graceful restarts where a new process is spawned.
-type corefileGob struct {
- ListenerFds map[string]uintptr
- Corefile Input
- OnDemandTLSCertsIssued int32
-}
-
-// IsRestart returns whether this process is, according
-// to env variables, a fork as part of a graceful restart.
-func IsRestart() bool {
- return os.Getenv("COREDNS_RESTART") == "true"
-}
-
-// writePidFile writes the process ID to the file at PidFile, if specified.
-func writePidFile() error {
- pid := []byte(strconv.Itoa(os.Getpid()) + "\n")
- return ioutil.WriteFile(PidFile, pid, 0644)
-}
-
-// CorefileInput represents a Corefile as input
-// and is simply a convenient way to implement
-// the Input interface.
-type CorefileInput struct {
- Filepath string
- Contents []byte
- RealFile bool
-}
-
-// Body returns c.Contents.
-func (c CorefileInput) Body() []byte { return c.Contents }
-
-// Path returns c.Filepath.
-func (c CorefileInput) Path() string { return c.Filepath }
-
-// IsFile returns true if the original input was a real file on the file system.
-func (c CorefileInput) IsFile() bool { return c.RealFile }
diff --git a/core/https/certificates.go b/core/https/certificates.go
deleted file mode 100644
index 6a8f3adc6..000000000
--- a/core/https/certificates.go
+++ /dev/null
@@ -1,234 +0,0 @@
-package https
-
-import (
- "crypto/tls"
- "crypto/x509"
- "errors"
- "io/ioutil"
- "log"
- "strings"
- "sync"
- "time"
-
- "github.com/xenolf/lego/acme"
- "golang.org/x/crypto/ocsp"
-)
-
-// certCache stores certificates in memory,
-// keying certificates by name.
-var certCache = make(map[string]Certificate)
-var certCacheMu sync.RWMutex
-
-// Certificate is a tls.Certificate with associated metadata tacked on.
-// Even if the metadata can be obtained by parsing the certificate,
-// we can be more efficient by extracting the metadata once so it's
-// just there, ready to use.
-type Certificate struct {
- tls.Certificate
-
- // Names is the list of names this certificate is written for.
- // The first is the CommonName (if any), the rest are SAN.
- Names []string
-
- // NotAfter is when the certificate expires.
- NotAfter time.Time
-
- // Managed certificates are certificates that CoreDNS is managing,
- // as opposed to the user specifying a certificate and key file
- // or directory and managing the certificate resources themselves.
- Managed bool
-
- // OnDemand certificates are obtained or loaded on-demand during TLS
- // handshakes (as opposed to preloaded certificates, which are loaded
- // at startup). If OnDemand is true, Managed must necessarily be true.
- // OnDemand certificates are maintained in the background just like
- // preloaded ones, however, if an OnDemand certificate fails to renew,
- // it is removed from the in-memory cache.
- OnDemand bool
-
- // OCSP contains the certificate's parsed OCSP response.
- OCSP *ocsp.Response
-}
-
-// getCertificate gets a certificate that matches name (a server name)
-// from the in-memory cache. If there is no exact match for name, it
-// will be checked against names of the form '*.example.com' (wildcard
-// certificates) according to RFC 6125. If a match is found, matched will
-// be true. If no matches are found, matched will be false and a default
-// certificate will be returned with defaulted set to true. If no default
-// certificate is set, defaulted will be set to false.
-//
-// The logic in this function is adapted from the Go standard library,
-// which is by the Go Authors.
-//
-// This function is safe for concurrent use.
-func getCertificate(name string) (cert Certificate, matched, defaulted bool) {
- var ok bool
-
- // Not going to trim trailing dots here since RFC 3546 says,
- // "The hostname is represented ... without a trailing dot."
- // Just normalize to lowercase.
- name = strings.ToLower(name)
-
- certCacheMu.RLock()
- defer certCacheMu.RUnlock()
-
- // exact match? great, let's use it
- if cert, ok = certCache[name]; ok {
- matched = true
- return
- }
-
- // try replacing labels in the name with wildcards until we get a match
- labels := strings.Split(name, ".")
- for i := range labels {
- labels[i] = "*"
- candidate := strings.Join(labels, ".")
- if cert, ok = certCache[candidate]; ok {
- matched = true
- return
- }
- }
-
- // if nothing matches, use the default certificate or bust
- cert, defaulted = certCache[""]
- return
-}
-
-// cacheManagedCertificate loads the certificate for domain into the
-// cache, flagging it as Managed and, if onDemand is true, as OnDemand
-// (meaning that it was obtained or loaded during a TLS handshake).
-//
-// This function is safe for concurrent use.
-func cacheManagedCertificate(domain string, onDemand bool) (Certificate, error) {
- cert, err := makeCertificateFromDisk(storage.SiteCertFile(domain), storage.SiteKeyFile(domain))
- if err != nil {
- return cert, err
- }
- cert.Managed = true
- cert.OnDemand = onDemand
- cacheCertificate(cert)
- return cert, nil
-}
-
-// cacheUnmanagedCertificatePEMFile loads a certificate for host using certFile
-// and keyFile, which must be in PEM format. It stores the certificate in
-// memory. The Managed and OnDemand flags of the certificate will be set to
-// false.
-//
-// This function is safe for concurrent use.
-func cacheUnmanagedCertificatePEMFile(certFile, keyFile string) error {
- cert, err := makeCertificateFromDisk(certFile, keyFile)
- if err != nil {
- return err
- }
- cacheCertificate(cert)
- return nil
-}
-
-// cacheUnmanagedCertificatePEMBytes makes a certificate out of the PEM bytes
-// of the certificate and key, then caches it in memory.
-//
-// This function is safe for concurrent use.
-func cacheUnmanagedCertificatePEMBytes(certBytes, keyBytes []byte) error {
- cert, err := makeCertificate(certBytes, keyBytes)
- if err != nil {
- return err
- }
- cacheCertificate(cert)
- return nil
-}
-
-// makeCertificateFromDisk makes a Certificate by loading the
-// certificate and key files. It fills out all the fields in
-// the certificate except for the Managed and OnDemand flags.
-// (It is up to the caller to set those.)
-func makeCertificateFromDisk(certFile, keyFile string) (Certificate, error) {
- certPEMBlock, err := ioutil.ReadFile(certFile)
- if err != nil {
- return Certificate{}, err
- }
- keyPEMBlock, err := ioutil.ReadFile(keyFile)
- if err != nil {
- return Certificate{}, err
- }
- return makeCertificate(certPEMBlock, keyPEMBlock)
-}
-
-// makeCertificate turns a certificate PEM bundle and a key PEM block into
-// a Certificate, with OCSP and other relevant metadata tagged with it,
-// except for the OnDemand and Managed flags. It is up to the caller to
-// set those properties.
-func makeCertificate(certPEMBlock, keyPEMBlock []byte) (Certificate, error) {
- var cert Certificate
-
- // Convert to a tls.Certificate
- tlsCert, err := tls.X509KeyPair(certPEMBlock, keyPEMBlock)
- if err != nil {
- return cert, err
- }
- if len(tlsCert.Certificate) == 0 {
- return cert, errors.New("certificate is empty")
- }
-
- // Parse leaf certificate and extract relevant metadata
- leaf, err := x509.ParseCertificate(tlsCert.Certificate[0])
- if err != nil {
- return cert, err
- }
- if leaf.Subject.CommonName != "" {
- cert.Names = []string{strings.ToLower(leaf.Subject.CommonName)}
- }
- for _, name := range leaf.DNSNames {
- if name != leaf.Subject.CommonName {
- cert.Names = append(cert.Names, strings.ToLower(name))
- }
- }
- cert.NotAfter = leaf.NotAfter
-
- // Staple OCSP
- ocspBytes, ocspResp, err := acme.GetOCSPForCert(certPEMBlock)
- if err != nil {
- // An error here is not a problem because a certificate may simply
- // not contain a link to an OCSP server. But we should log it anyway.
- log.Printf("[WARNING] No OCSP stapling for %v: %v", cert.Names, err)
- } else if ocspResp.Status == ocsp.Good {
- tlsCert.OCSPStaple = ocspBytes
- cert.OCSP = ocspResp
- }
-
- cert.Certificate = tlsCert
- return cert, nil
-}
-
-// cacheCertificate adds cert to the in-memory cache. If the cache is
-// empty, cert will be used as the default certificate. If the cache is
-// full, random entries are deleted until there is room to map all the
-// names on the certificate.
-//
-// This certificate will be keyed to the names in cert.Names. Any name
-// that is already a key in the cache will be replaced with this cert.
-//
-// This function is safe for concurrent use.
-func cacheCertificate(cert Certificate) {
- certCacheMu.Lock()
- if _, ok := certCache[""]; !ok {
- // use as default
- cert.Names = append(cert.Names, "")
- certCache[""] = cert
- }
- for len(certCache)+len(cert.Names) > 10000 {
- // for simplicity, just remove random elements
- for key := range certCache {
- if key == "" { // ... but not the default cert
- continue
- }
- delete(certCache, key)
- break
- }
- }
- for _, name := range cert.Names {
- certCache[name] = cert
- }
- certCacheMu.Unlock()
-}
diff --git a/core/https/certificates_test.go b/core/https/certificates_test.go
deleted file mode 100644
index dbfb4efc1..000000000
--- a/core/https/certificates_test.go
+++ /dev/null
@@ -1,59 +0,0 @@
-package https
-
-import "testing"
-
-func TestUnexportedGetCertificate(t *testing.T) {
- defer func() { certCache = make(map[string]Certificate) }()
-
- // When cache is empty
- if _, matched, defaulted := getCertificate("example.com"); matched || defaulted {
- t.Errorf("Got a certificate when cache was empty; matched=%v, defaulted=%v", matched, defaulted)
- }
-
- // When cache has one certificate in it (also is default)
- defaultCert := Certificate{Names: []string{"example.com", ""}}
- certCache[""] = defaultCert
- certCache["example.com"] = defaultCert
- if cert, matched, defaulted := getCertificate("Example.com"); !matched || defaulted || cert.Names[0] != "example.com" {
- t.Errorf("Didn't get a cert for 'Example.com' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted)
- }
- if cert, matched, defaulted := getCertificate(""); !matched || defaulted || cert.Names[0] != "example.com" {
- t.Errorf("Didn't get a cert for '' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted)
- }
-
- // When retrieving wildcard certificate
- certCache["*.example.com"] = Certificate{Names: []string{"*.example.com"}}
- if cert, matched, defaulted := getCertificate("sub.example.com"); !matched || defaulted || cert.Names[0] != "*.example.com" {
- t.Errorf("Didn't get wildcard cert for 'sub.example.com' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted)
- }
-
- // When no certificate matches, the default is returned
- if cert, matched, defaulted := getCertificate("nomatch"); matched || !defaulted {
- t.Errorf("Expected matched=false, defaulted=true; but got matched=%v, defaulted=%v (cert: %v)", matched, defaulted, cert)
- } else if cert.Names[0] != "example.com" {
- t.Errorf("Expected default cert, got: %v", cert)
- }
-}
-
-func TestCacheCertificate(t *testing.T) {
- defer func() { certCache = make(map[string]Certificate) }()
-
- cacheCertificate(Certificate{Names: []string{"example.com", "sub.example.com"}})
- if _, ok := certCache["example.com"]; !ok {
- t.Error("Expected first cert to be cached by key 'example.com', but it wasn't")
- }
- if _, ok := certCache["sub.example.com"]; !ok {
- t.Error("Expected first cert to be cached by key 'sub.exmaple.com', but it wasn't")
- }
- if cert, ok := certCache[""]; !ok || cert.Names[2] != "" {
- t.Error("Expected first cert to be cached additionally as the default certificate with empty name added, but it wasn't")
- }
-
- cacheCertificate(Certificate{Names: []string{"example2.com"}})
- if _, ok := certCache["example2.com"]; !ok {
- t.Error("Expected second cert to be cached by key 'exmaple2.com', but it wasn't")
- }
- if cert, ok := certCache[""]; ok && cert.Names[0] == "example2.com" {
- t.Error("Expected second cert to NOT be cached as default, but it was")
- }
-}
diff --git a/core/https/client.go b/core/https/client.go
deleted file mode 100644
index e9e8cd82c..000000000
--- a/core/https/client.go
+++ /dev/null
@@ -1,215 +0,0 @@
-package https
-
-import (
- "encoding/json"
- "errors"
- "fmt"
- "io/ioutil"
- "net"
- "sync"
- "time"
-
- "github.com/miekg/coredns/server"
- "github.com/xenolf/lego/acme"
-)
-
-// acmeMu ensures that only one ACME challenge occurs at a time.
-var acmeMu sync.Mutex
-
-// ACMEClient is an acme.Client with custom state attached.
-type ACMEClient struct {
- *acme.Client
- AllowPrompts bool // if false, we assume AlternatePort must be used
-}
-
-// NewACMEClient creates a new ACMEClient given an email and whether
-// prompting the user is allowed. Clients should not be kept and
-// re-used over long periods of time, but immediate re-use is more
-// efficient than re-creating on every iteration.
-var NewACMEClient = func(email string, allowPrompts bool) (*ACMEClient, error) {
- // Look up or create the LE user account
- leUser, err := getUser(email)
- if err != nil {
- return nil, err
- }
-
- // The client facilitates our communication with the CA server.
- client, err := acme.NewClient(CAUrl, &leUser, KeyType)
- if err != nil {
- return nil, err
- }
-
- // If not registered, the user must register an account with the CA
- // and agree to terms
- if leUser.Registration == nil {
- reg, err := client.Register()
- if err != nil {
- return nil, errors.New("registration error: " + err.Error())
- }
- leUser.Registration = reg
-
- if allowPrompts { // can't prompt a user who isn't there
- if !Agreed && reg.TosURL == "" {
- Agreed = promptUserAgreement(saURL, false) // TODO - latest URL
- }
- if !Agreed && reg.TosURL == "" {
- return nil, errors.New("user must agree to terms")
- }
- }
-
- err = client.AgreeToTOS()
- if err != nil {
- saveUser(leUser) // Might as well try, right?
- return nil, errors.New("error agreeing to terms: " + err.Error())
- }
-
- // save user to the file system
- err = saveUser(leUser)
- if err != nil {
- return nil, errors.New("could not save user: " + err.Error())
- }
- }
-
- return &ACMEClient{
- Client: client,
- AllowPrompts: allowPrompts,
- }, nil
-}
-
-// NewACMEClientGetEmail creates a new ACMEClient and gets an email
-// address at the same time (a server config is required, since it
-// may contain an email address in it).
-func NewACMEClientGetEmail(config server.Config, allowPrompts bool) (*ACMEClient, error) {
- return NewACMEClient(getEmail(config, allowPrompts), allowPrompts)
-}
-
-// Configure configures c according to bindHost, which is the host (not
-// whole address) to bind the listener to in solving the http and tls-sni
-// challenges.
-func (c *ACMEClient) Configure(bindHost string) {
- // If we allow prompts, operator must be present. In our case,
- // that is synonymous with saying the server is not already
- // started. So if the user is still there, we don't use
- // AlternatePort because we don't need to proxy the challenges.
- // Conversely, if the operator is not there, the server has
- // already started and we need to proxy the challenge.
- if c.AllowPrompts {
- // Operator is present; server is not already listening
- c.SetHTTPAddress(net.JoinHostPort(bindHost, ""))
- c.SetTLSAddress(net.JoinHostPort(bindHost, ""))
- //c.ExcludeChallenges([]acme.Challenge{acme.DNS01})
- } else {
- // Operator is not present; server is started, so proxy challenges
- c.SetHTTPAddress(net.JoinHostPort(bindHost, AlternatePort))
- c.SetTLSAddress(net.JoinHostPort(bindHost, AlternatePort))
- //c.ExcludeChallenges([]acme.Challenge{acme.TLSSNI01, acme.DNS01})
- }
- c.ExcludeChallenges([]acme.Challenge{acme.TLSSNI01, acme.DNS01}) // TODO: can we proxy TLS challenges? and we should support DNS...
-}
-
-// Obtain obtains a single certificate for names. It stores the certificate
-// on the disk if successful.
-func (c *ACMEClient) Obtain(names []string) error {
-Attempts:
- for attempts := 0; attempts < 2; attempts++ {
- acmeMu.Lock()
- certificate, failures := c.ObtainCertificate(names, true, nil)
- acmeMu.Unlock()
- if len(failures) > 0 {
- // Error - try to fix it or report it to the user and abort
- var errMsg string // we'll combine all the failures into a single error message
- var promptedForAgreement bool // only prompt user for agreement at most once
-
- for errDomain, obtainErr := range failures {
- // TODO: Double-check, will obtainErr ever be nil?
- if tosErr, ok := obtainErr.(acme.TOSError); ok {
- // Terms of Service agreement error; we can probably deal with this
- if !Agreed && !promptedForAgreement && c.AllowPrompts {
- Agreed = promptUserAgreement(tosErr.Detail, true) // TODO: Use latest URL
- promptedForAgreement = true
- }
- if Agreed || !c.AllowPrompts {
- err := c.AgreeToTOS()
- if err != nil {
- return errors.New("error agreeing to updated terms: " + err.Error())
- }
- continue Attempts
- }
- }
-
- // If user did not agree or it was any other kind of error, just append to the list of errors
- errMsg += "[" + errDomain + "] failed to get certificate: " + obtainErr.Error() + "\n"
- }
- return errors.New(errMsg)
- }
-
- // Success - immediately save the certificate resource
- err := saveCertResource(certificate)
- if err != nil {
- return fmt.Errorf("error saving assets for %v: %v", names, err)
- }
-
- break
- }
-
- return nil
-}
-
-// Renew renews the managed certificate for name. Right now our storage
-// mechanism only supports one name per certificate, so this function only
-// accepts one domain as input. It can be easily modified to support SAN
-// certificates if, one day, they become desperately needed enough that our
-// storage mechanism is upgraded to be more complex to support SAN certs.
-//
-// Anyway, this function is safe for concurrent use.
-func (c *ACMEClient) Renew(name string) error {
- // Prepare for renewal (load PEM cert, key, and meta)
- certBytes, err := ioutil.ReadFile(storage.SiteCertFile(name))
- if err != nil {
- return err
- }
- keyBytes, err := ioutil.ReadFile(storage.SiteKeyFile(name))
- if err != nil {
- return err
- }
- metaBytes, err := ioutil.ReadFile(storage.SiteMetaFile(name))
- if err != nil {
- return err
- }
- var certMeta acme.CertificateResource
- err = json.Unmarshal(metaBytes, &certMeta)
- certMeta.Certificate = certBytes
- certMeta.PrivateKey = keyBytes
-
- // Perform renewal and retry if necessary, but not too many times.
- var newCertMeta acme.CertificateResource
- var success bool
- for attempts := 0; attempts < 2; attempts++ {
- acmeMu.Lock()
- newCertMeta, err = c.RenewCertificate(certMeta, true)
- acmeMu.Unlock()
- if err == nil {
- success = true
- break
- }
-
- // If the legal terms changed and need to be agreed to again,
- // we can handle that.
- if _, ok := err.(acme.TOSError); ok {
- err := c.AgreeToTOS()
- if err != nil {
- return err
- }
- continue
- }
-
- // For any other kind of error, wait 10s and try again.
- time.Sleep(10 * time.Second)
- }
-
- if !success {
- return errors.New("too many renewal attempts; last error: " + err.Error())
- }
-
- return saveCertResource(newCertMeta)
-}
diff --git a/core/https/crypto.go b/core/https/crypto.go
deleted file mode 100644
index 7971bda36..000000000
--- a/core/https/crypto.go
+++ /dev/null
@@ -1,57 +0,0 @@
-package https
-
-import (
- "crypto"
- "crypto/ecdsa"
- "crypto/rsa"
- "crypto/x509"
- "encoding/pem"
- "errors"
- "io/ioutil"
- "os"
-)
-
-// loadPrivateKey loads a PEM-encoded ECC/RSA private key from file.
-func loadPrivateKey(file string) (crypto.PrivateKey, error) {
- keyBytes, err := ioutil.ReadFile(file)
- if err != nil {
- return nil, err
- }
- keyBlock, _ := pem.Decode(keyBytes)
-
- switch keyBlock.Type {
- case "RSA PRIVATE KEY":
- return x509.ParsePKCS1PrivateKey(keyBlock.Bytes)
- case "EC PRIVATE KEY":
- return x509.ParseECPrivateKey(keyBlock.Bytes)
- }
-
- return nil, errors.New("unknown private key type")
-}
-
-// savePrivateKey saves a PEM-encoded ECC/RSA private key to file.
-func savePrivateKey(key crypto.PrivateKey, file string) error {
- var pemType string
- var keyBytes []byte
- switch key := key.(type) {
- case *ecdsa.PrivateKey:
- var err error
- pemType = "EC"
- keyBytes, err = x509.MarshalECPrivateKey(key)
- if err != nil {
- return err
- }
- case *rsa.PrivateKey:
- pemType = "RSA"
- keyBytes = x509.MarshalPKCS1PrivateKey(key)
- }
-
- pemKey := pem.Block{Type: pemType + " PRIVATE KEY", Bytes: keyBytes}
- keyOut, err := os.Create(file)
- if err != nil {
- return err
- }
- keyOut.Chmod(0600)
- defer keyOut.Close()
- return pem.Encode(keyOut, &pemKey)
-}
diff --git a/core/https/crypto_test.go b/core/https/crypto_test.go
deleted file mode 100644
index 07d2af5c7..000000000
--- a/core/https/crypto_test.go
+++ /dev/null
@@ -1,111 +0,0 @@
-package https
-
-import (
- "bytes"
- "crypto"
- "crypto/ecdsa"
- "crypto/elliptic"
- "crypto/rand"
- "crypto/rsa"
- "crypto/x509"
- "os"
- "runtime"
- "testing"
-)
-
-func TestSaveAndLoadRSAPrivateKey(t *testing.T) {
- keyFile := "test.key"
- defer os.Remove(keyFile)
-
- privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
- if err != nil {
- t.Fatal(err)
- }
-
- // test save
- err = savePrivateKey(privateKey, keyFile)
- if err != nil {
- t.Fatal("error saving private key:", err)
- }
-
- // it doesn't make sense to test file permission on windows
- if runtime.GOOS != "windows" {
- // get info of the key file
- info, err := os.Stat(keyFile)
- if err != nil {
- t.Fatal("error stating private key:", err)
- }
- // verify permission of key file is correct
- if info.Mode().Perm() != 0600 {
- t.Error("Expected key file to have permission 0600, but it wasn't")
- }
- }
-
- // test load
- loadedKey, err := loadPrivateKey(keyFile)
- if err != nil {
- t.Error("error loading private key:", err)
- }
-
- // verify loaded key is correct
- if !PrivateKeysSame(privateKey, loadedKey) {
- t.Error("Expected key bytes to be the same, but they weren't")
- }
-}
-
-func TestSaveAndLoadECCPrivateKey(t *testing.T) {
- keyFile := "test.key"
- defer os.Remove(keyFile)
-
- privateKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
- if err != nil {
- t.Fatal(err)
- }
-
- // test save
- err = savePrivateKey(privateKey, keyFile)
- if err != nil {
- t.Fatal("error saving private key:", err)
- }
-
- // it doesn't make sense to test file permission on windows
- if runtime.GOOS != "windows" {
- // get info of the key file
- info, err := os.Stat(keyFile)
- if err != nil {
- t.Fatal("error stating private key:", err)
- }
- // verify permission of key file is correct
- if info.Mode().Perm() != 0600 {
- t.Error("Expected key file to have permission 0600, but it wasn't")
- }
- }
-
- // test load
- loadedKey, err := loadPrivateKey(keyFile)
- if err != nil {
- t.Error("error loading private key:", err)
- }
-
- // verify loaded key is correct
- if !PrivateKeysSame(privateKey, loadedKey) {
- t.Error("Expected key bytes to be the same, but they weren't")
- }
-}
-
-// PrivateKeysSame compares the bytes of a and b and returns true if they are the same.
-func PrivateKeysSame(a, b crypto.PrivateKey) bool {
- return bytes.Equal(PrivateKeyBytes(a), PrivateKeyBytes(b))
-}
-
-// PrivateKeyBytes returns the bytes of DER-encoded key.
-func PrivateKeyBytes(key crypto.PrivateKey) []byte {
- var keyBytes []byte
- switch key := key.(type) {
- case *rsa.PrivateKey:
- keyBytes = x509.MarshalPKCS1PrivateKey(key)
- case *ecdsa.PrivateKey:
- keyBytes, _ = x509.MarshalECPrivateKey(key)
- }
- return keyBytes
-}
diff --git a/core/https/handler.go b/core/https/handler.go
deleted file mode 100644
index f3139f54e..000000000
--- a/core/https/handler.go
+++ /dev/null
@@ -1,42 +0,0 @@
-package https
-
-import (
- "crypto/tls"
- "log"
- "net/http"
- "net/http/httputil"
- "net/url"
- "strings"
-)
-
-const challengeBasePath = "/.well-known/acme-challenge"
-
-// RequestCallback proxies challenge requests to ACME client if the
-// request path starts with challengeBasePath. It returns true if it
-// handled the request and no more needs to be done; it returns false
-// if this call was a no-op and the request still needs handling.
-func RequestCallback(w http.ResponseWriter, r *http.Request) bool {
- if strings.HasPrefix(r.URL.Path, challengeBasePath) {
- scheme := "http"
- if r.TLS != nil {
- scheme = "https"
- }
-
- upstream, err := url.Parse(scheme + "://localhost:" + AlternatePort)
- if err != nil {
- w.WriteHeader(http.StatusInternalServerError)
- log.Printf("[ERROR] ACME proxy handler: %v", err)
- return true
- }
-
- proxy := httputil.NewSingleHostReverseProxy(upstream)
- proxy.Transport = &http.Transport{
- TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // solver uses self-signed certs
- }
- proxy.ServeHTTP(w, r)
-
- return true
- }
-
- return false
-}
diff --git a/core/https/handler_test.go b/core/https/handler_test.go
deleted file mode 100644
index 016799ffb..000000000
--- a/core/https/handler_test.go
+++ /dev/null
@@ -1,63 +0,0 @@
-package https
-
-import (
- "net"
- "net/http"
- "net/http/httptest"
- "testing"
-)
-
-func TestRequestCallbackNoOp(t *testing.T) {
- // try base paths that aren't handled by this handler
- for _, url := range []string{
- "http://localhost/",
- "http://localhost/foo.html",
- "http://localhost/.git",
- "http://localhost/.well-known/",
- "http://localhost/.well-known/acme-challenging",
- } {
- req, err := http.NewRequest("GET", url, nil)
- if err != nil {
- t.Fatalf("Could not craft request, got error: %v", err)
- }
- rw := httptest.NewRecorder()
- if RequestCallback(rw, req) {
- t.Errorf("Got true with this URL, but shouldn't have: %s", url)
- }
- }
-}
-
-func TestRequestCallbackSuccess(t *testing.T) {
- expectedPath := challengeBasePath + "/asdf"
-
- // Set up fake acme handler backend to make sure proxying succeeds
- var proxySuccess bool
- ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- proxySuccess = true
- if r.URL.Path != expectedPath {
- t.Errorf("Expected path '%s' but got '%s' instead", expectedPath, r.URL.Path)
- }
- }))
-
- // Custom listener that uses the port we expect
- ln, err := net.Listen("tcp", "127.0.0.1:"+AlternatePort)
- if err != nil {
- t.Fatalf("Unable to start test server listener: %v", err)
- }
- ts.Listener = ln
-
- // Start our engines and run the test
- ts.Start()
- defer ts.Close()
- req, err := http.NewRequest("GET", "http://127.0.0.1:"+AlternatePort+expectedPath, nil)
- if err != nil {
- t.Fatalf("Could not craft request, got error: %v", err)
- }
- rw := httptest.NewRecorder()
-
- RequestCallback(rw, req)
-
- if !proxySuccess {
- t.Fatal("Expected request to be proxied, but it wasn't")
- }
-}
diff --git a/core/https/handshake.go b/core/https/handshake.go
deleted file mode 100644
index a05231c49..000000000
--- a/core/https/handshake.go
+++ /dev/null
@@ -1,316 +0,0 @@
-package https
-
-import (
- "bytes"
- "crypto/tls"
- "encoding/pem"
- "errors"
- "fmt"
- "log"
- "strings"
- "sync"
- "sync/atomic"
- "time"
-
- "github.com/miekg/coredns/server"
- "github.com/xenolf/lego/acme"
-)
-
-// GetCertificate gets a certificate to satisfy clientHello as long as
-// the certificate is already cached in memory. It will not be loaded
-// from disk or obtained from the CA during the handshake.
-//
-// This function is safe for use as a tls.Config.GetCertificate callback.
-func GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
- cert, err := getCertDuringHandshake(clientHello.ServerName, false, false)
- return &cert.Certificate, err
-}
-
-// GetOrObtainCertificate will get a certificate to satisfy clientHello, even
-// if that means obtaining a new certificate from a CA during the handshake.
-// It first checks the in-memory cache, then accesses disk, then accesses the
-// network if it must. An obtained certificate will be stored on disk and
-// cached in memory.
-//
-// This function is safe for use as a tls.Config.GetCertificate callback.
-func GetOrObtainCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
- cert, err := getCertDuringHandshake(clientHello.ServerName, true, true)
- return &cert.Certificate, err
-}
-
-// getCertDuringHandshake will get a certificate for name. It first tries
-// the in-memory cache. If no certificate for name is in the cache and if
-// loadIfNecessary == true, it goes to disk to load it into the cache and
-// serve it. If it's not on disk and if obtainIfNecessary == true, the
-// certificate will be obtained from the CA, cached, and served. If
-// obtainIfNecessary is true, then loadIfNecessary must also be set to true.
-// An error will be returned if and only if no certificate is available.
-//
-// This function is safe for concurrent use.
-func getCertDuringHandshake(name string, loadIfNecessary, obtainIfNecessary bool) (Certificate, error) {
- // First check our in-memory cache to see if we've already loaded it
- cert, matched, defaulted := getCertificate(name)
- if matched {
- return cert, nil
- }
-
- if loadIfNecessary {
- // Then check to see if we have one on disk
- loadedCert, err := cacheManagedCertificate(name, true)
- if err == nil {
- loadedCert, err = handshakeMaintenance(name, loadedCert)
- if err != nil {
- log.Printf("[ERROR] Maintaining newly-loaded certificate for %s: %v", name, err)
- }
- return loadedCert, nil
- }
-
- if obtainIfNecessary {
- // By this point, we need to ask the CA for a certificate
-
- name = strings.ToLower(name)
-
- // Make sure aren't over any applicable limits
- err := checkLimitsForObtainingNewCerts(name)
- if err != nil {
- return Certificate{}, err
- }
-
- // TODO(miek): deleted, tls will be enabled when a keyword is specified.
- // Obtain certificate from the CA
- return obtainOnDemandCertificate(name)
- }
- }
-
- if defaulted {
- return cert, nil
- }
-
- return Certificate{}, errors.New("no certificate for " + name)
-}
-
-// checkLimitsForObtainingNewCerts checks to see if name can be issued right
-// now according to mitigating factors we keep track of and preferences the
-// user has set. If a non-nil error is returned, do not issue a new certificate
-// for name.
-func checkLimitsForObtainingNewCerts(name string) error {
- // User can set hard limit for number of certs for the process to issue
- if onDemandMaxIssue > 0 && atomic.LoadInt32(OnDemandIssuedCount) >= onDemandMaxIssue {
- return fmt.Errorf("%s: maximum certificates issued (%d)", name, onDemandMaxIssue)
- }
-
- // Make sure name hasn't failed a challenge recently
- failedIssuanceMu.RLock()
- when, ok := failedIssuance[name]
- failedIssuanceMu.RUnlock()
- if ok {
- return fmt.Errorf("%s: throttled; refusing to issue cert since last attempt on %s failed", name, when.String())
- }
-
- // Make sure, if we've issued a few certificates already, that we haven't
- // issued any recently
- lastIssueTimeMu.Lock()
- since := time.Since(lastIssueTime)
- lastIssueTimeMu.Unlock()
- if atomic.LoadInt32(OnDemandIssuedCount) >= 10 && since < 10*time.Minute {
- return fmt.Errorf("%s: throttled; last certificate was obtained %v ago", name, since)
- }
-
- // 👍Good to go
- return nil
-}
-
-// obtainOnDemandCertificate obtains a certificate for name for the given
-// name. If another goroutine has already started obtaining a cert for
-// name, it will wait and use what the other goroutine obtained.
-//
-// This function is safe for use by multiple concurrent goroutines.
-func obtainOnDemandCertificate(name string) (Certificate, error) {
- // We must protect this process from happening concurrently, so synchronize.
- obtainCertWaitChansMu.Lock()
- wait, ok := obtainCertWaitChans[name]
- if ok {
- // lucky us -- another goroutine is already obtaining the certificate.
- // wait for it to finish obtaining the cert and then we'll use it.
- obtainCertWaitChansMu.Unlock()
- <-wait
- return getCertDuringHandshake(name, true, false)
- }
-
- // looks like it's up to us to do all the work and obtain the cert
- wait = make(chan struct{})
- obtainCertWaitChans[name] = wait
- obtainCertWaitChansMu.Unlock()
-
- // Unblock waiters and delete waitgroup when we return
- defer func() {
- obtainCertWaitChansMu.Lock()
- close(wait)
- delete(obtainCertWaitChans, name)
- obtainCertWaitChansMu.Unlock()
- }()
-
- log.Printf("[INFO] Obtaining new certificate for %s", name)
-
- // obtain cert
- client, err := NewACMEClientGetEmail(server.Config{}, false)
- if err != nil {
- return Certificate{}, errors.New("error creating client: " + err.Error())
- }
- client.Configure("") // TODO: which BindHost?
- err = client.Obtain([]string{name})
- if err != nil {
- // Failed to solve challenge, so don't allow another on-demand
- // issue for this name to be attempted for a little while.
- failedIssuanceMu.Lock()
- failedIssuance[name] = time.Now()
- go func(name string) {
- time.Sleep(5 * time.Minute)
- failedIssuanceMu.Lock()
- delete(failedIssuance, name)
- failedIssuanceMu.Unlock()
- }(name)
- failedIssuanceMu.Unlock()
- return Certificate{}, err
- }
-
- // Success - update counters and stuff
- atomic.AddInt32(OnDemandIssuedCount, 1)
- lastIssueTimeMu.Lock()
- lastIssueTime = time.Now()
- lastIssueTimeMu.Unlock()
-
- // The certificate is already on disk; now just start over to load it and serve it
- return getCertDuringHandshake(name, true, false)
-}
-
-// handshakeMaintenance performs a check on cert for expiration and OCSP
-// validity.
-//
-// This function is safe for use by multiple concurrent goroutines.
-func handshakeMaintenance(name string, cert Certificate) (Certificate, error) {
- // Check cert expiration
- timeLeft := cert.NotAfter.Sub(time.Now().UTC())
- if timeLeft < renewDurationBefore {
- log.Printf("[INFO] Certificate for %v expires in %v; attempting renewal", cert.Names, timeLeft)
- return renewDynamicCertificate(name)
- }
-
- // Check OCSP staple validity
- if cert.OCSP != nil {
- refreshTime := cert.OCSP.ThisUpdate.Add(cert.OCSP.NextUpdate.Sub(cert.OCSP.ThisUpdate) / 2)
- if time.Now().After(refreshTime) {
- err := stapleOCSP(&cert, nil)
- if err != nil {
- // An error with OCSP stapling is not the end of the world, and in fact, is
- // quite common considering not all certs have issuer URLs that support it.
- log.Printf("[ERROR] Getting OCSP for %s: %v", name, err)
- }
- certCacheMu.Lock()
- certCache[name] = cert
- certCacheMu.Unlock()
- }
- }
-
- return cert, nil
-}
-
-// renewDynamicCertificate renews currentCert using the clientHello. It returns the
-// certificate to use and an error, if any. currentCert may be returned even if an
-// error occurs, since we perform renewals before they expire and it may still be
-// usable. name should already be lower-cased before calling this function.
-//
-// This function is safe for use by multiple concurrent goroutines.
-func renewDynamicCertificate(name string) (Certificate, error) {
- obtainCertWaitChansMu.Lock()
- wait, ok := obtainCertWaitChans[name]
- if ok {
- // lucky us -- another goroutine is already renewing the certificate.
- // wait for it to finish, then we'll use the new one.
- obtainCertWaitChansMu.Unlock()
- <-wait
- return getCertDuringHandshake(name, true, false)
- }
-
- // looks like it's up to us to do all the work and renew the cert
- wait = make(chan struct{})
- obtainCertWaitChans[name] = wait
- obtainCertWaitChansMu.Unlock()
-
- // unblock waiters and delete waitgroup when we return
- defer func() {
- obtainCertWaitChansMu.Lock()
- close(wait)
- delete(obtainCertWaitChans, name)
- obtainCertWaitChansMu.Unlock()
- }()
-
- log.Printf("[INFO] Renewing certificate for %s", name)
-
- client, err := NewACMEClientGetEmail(server.Config{}, false)
- if err != nil {
- return Certificate{}, err
- }
- client.Configure("") // TODO: Bind address of relevant listener, yuck
- err = client.Renew(name)
- if err != nil {
- return Certificate{}, err
- }
-
- return getCertDuringHandshake(name, true, false)
-}
-
-// stapleOCSP staples OCSP information to cert for hostname name.
-// If you have it handy, you should pass in the PEM-encoded certificate
-// bundle; otherwise the DER-encoded cert will have to be PEM-encoded.
-// If you don't have the PEM blocks handy, just pass in nil.
-//
-// Errors here are not necessarily fatal, it could just be that the
-// certificate doesn't have an issuer URL.
-func stapleOCSP(cert *Certificate, pemBundle []byte) error {
- if pemBundle == nil {
- // The function in the acme package that gets OCSP requires a PEM-encoded cert
- bundle := new(bytes.Buffer)
- for _, derBytes := range cert.Certificate.Certificate {
- pem.Encode(bundle, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
- }
- pemBundle = bundle.Bytes()
- }
-
- ocspBytes, ocspResp, err := acme.GetOCSPForCert(pemBundle)
- if err != nil {
- return err
- }
-
- cert.Certificate.OCSPStaple = ocspBytes
- cert.OCSP = ocspResp
-
- return nil
-}
-
-// obtainCertWaitChans is used to coordinate obtaining certs for each hostname.
-var obtainCertWaitChans = make(map[string]chan struct{})
-var obtainCertWaitChansMu sync.Mutex
-
-// OnDemandIssuedCount is the number of certificates that have been issued
-// on-demand by this process. It is only safe to modify this count atomically.
-// If it reaches onDemandMaxIssue, on-demand issuances will fail.
-var OnDemandIssuedCount = new(int32)
-
-// onDemandMaxIssue is set based on max_certs in tls config. It specifies the
-// maximum number of certificates that can be issued.
-// TODO: This applies globally, but we should probably make a server-specific
-// way to keep track of these limits and counts, since it's specified in the
-// Corefile...
-var onDemandMaxIssue int32
-
-// failedIssuance is a set of names that we recently failed to get a
-// certificate for from the ACME CA. They are removed after some time.
-// When a name is in this map, do not issue a certificate for it on-demand.
-var failedIssuance = make(map[string]time.Time)
-var failedIssuanceMu sync.RWMutex
-
-// lastIssueTime records when we last obtained a certificate successfully.
-// If this value is recent, do not make any on-demand certificate requests.
-var lastIssueTime time.Time
-var lastIssueTimeMu sync.Mutex
diff --git a/core/https/handshake_test.go b/core/https/handshake_test.go
deleted file mode 100644
index cf70eb17d..000000000
--- a/core/https/handshake_test.go
+++ /dev/null
@@ -1,54 +0,0 @@
-package https
-
-import (
- "crypto/tls"
- "crypto/x509"
- "testing"
-)
-
-func TestGetCertificate(t *testing.T) {
- defer func() { certCache = make(map[string]Certificate) }()
-
- hello := &tls.ClientHelloInfo{ServerName: "example.com"}
- helloSub := &tls.ClientHelloInfo{ServerName: "sub.example.com"}
- helloNoSNI := &tls.ClientHelloInfo{}
- helloNoMatch := &tls.ClientHelloInfo{ServerName: "nomatch"}
-
- // When cache is empty
- if cert, err := GetCertificate(hello); err == nil {
- t.Errorf("GetCertificate should return error when cache is empty, got: %v", cert)
- }
- if cert, err := GetCertificate(helloNoSNI); err == nil {
- t.Errorf("GetCertificate should return error when cache is empty even if server name is blank, got: %v", cert)
- }
-
- // When cache has one certificate in it (also is default)
- defaultCert := Certificate{Names: []string{"example.com", ""}, Certificate: tls.Certificate{Leaf: &x509.Certificate{DNSNames: []string{"example.com"}}}}
- certCache[""] = defaultCert
- certCache["example.com"] = defaultCert
- if cert, err := GetCertificate(hello); err != nil {
- t.Errorf("Got an error but shouldn't have, when cert exists in cache: %v", err)
- } else if cert.Leaf.DNSNames[0] != "example.com" {
- t.Errorf("Got wrong certificate with exact match; expected 'example.com', got: %v", cert)
- }
- if cert, err := GetCertificate(helloNoSNI); err != nil {
- t.Errorf("Got an error with no SNI but shouldn't have, when cert exists in cache: %v", err)
- } else if cert.Leaf.DNSNames[0] != "example.com" {
- t.Errorf("Got wrong certificate for no SNI; expected 'example.com' as default, got: %v", cert)
- }
-
- // When retrieving wildcard certificate
- certCache["*.example.com"] = Certificate{Names: []string{"*.example.com"}, Certificate: tls.Certificate{Leaf: &x509.Certificate{DNSNames: []string{"*.example.com"}}}}
- if cert, err := GetCertificate(helloSub); err != nil {
- t.Errorf("Didn't get wildcard cert, got: cert=%v, err=%v ", cert, err)
- } else if cert.Leaf.DNSNames[0] != "*.example.com" {
- t.Errorf("Got wrong certificate, expected wildcard: %v", cert)
- }
-
- // When no certificate matches, the default is returned
- if cert, err := GetCertificate(helloNoMatch); err != nil {
- t.Errorf("Expected default certificate with no error when no matches, got err: %v", err)
- } else if cert.Leaf.DNSNames[0] != "example.com" {
- t.Errorf("Expected default cert with no matches, got: %v", cert)
- }
-}
diff --git a/core/https/https.go b/core/https/https.go
deleted file mode 100644
index 99ef2fef6..000000000
--- a/core/https/https.go
+++ /dev/null
@@ -1,339 +0,0 @@
-// Package https facilitates the management of TLS assets and integrates
-// Let's Encrypt functionality into CoreDNS with first-class support for
-// creating and renewing certificates automatically. It is designed to
-// configure sites for HTTPS by default.
-package https
-
-import (
- "encoding/json"
- "errors"
- "io/ioutil"
- "net"
- "os"
-
- "github.com/miekg/coredns/server"
- "github.com/xenolf/lego/acme"
-)
-
-// Activate sets up TLS for each server config in configs
-// as needed; this consists of acquiring and maintaining
-// certificates and keys for qualifying configs and enabling
-// OCSP stapling for all TLS-enabled configs.
-//
-// This function may prompt the user to provide an email
-// address if none is available through other means. It
-// prefers the email address specified in the config, but
-// if that is not available it will check the command line
-// argument. If absent, it will use the most recent email
-// address from last time. If there isn't one, the user
-// will be prompted and shown SA link.
-//
-// Also note that calling this function activates asset
-// management automatically, which keeps certificates
-// renewed and OCSP stapling updated.
-//
-// Activate returns the updated list of configs, since
-// some may have been appended, for example, to redirect
-// plaintext HTTP requests to their HTTPS counterpart.
-// This function only appends; it does not splice.
-func Activate(configs []server.Config) ([]server.Config, error) {
- // just in case previous caller forgot...
- Deactivate()
-
- // pre-screen each config and earmark the ones that qualify for managed TLS
- MarkQualified(configs)
-
- // place certificates and keys on disk
- err := ObtainCerts(configs, true, false)
- if err != nil {
- return configs, err
- }
-
- // update TLS configurations
- err = EnableTLS(configs, true)
- if err != nil {
- return configs, err
- }
-
- // renew all relevant certificates that need renewal. this is important
- // to do right away for a couple reasons, mainly because each restart,
- // the renewal ticker is reset, so if restarts happen more often than
- // the ticker interval, renewals would never happen. but doing
- // it right away at start guarantees that renewals aren't missed.
- err = renewManagedCertificates(true)
- if err != nil {
- return configs, err
- }
-
- // keep certificates renewed and OCSP stapling updated
- go maintainAssets(stopChan)
-
- return configs, nil
-}
-
-// Deactivate cleans up long-term, in-memory resources
-// allocated by calling Activate(). Essentially, it stops
-// the asset maintainer from running, meaning that certificates
-// will not be renewed, OCSP staples will not be updated, etc.
-func Deactivate() (err error) {
- defer func() {
- if rec := recover(); rec != nil {
- err = errors.New("already deactivated")
- }
- }()
- close(stopChan)
- stopChan = make(chan struct{})
- return
-}
-
-// MarkQualified scans each config and, if it qualifies for managed
-// TLS, it sets the Managed field of the TLSConfig to true.
-func MarkQualified(configs []server.Config) {
- for i := 0; i < len(configs); i++ {
- if ConfigQualifies(configs[i]) {
- configs[i].TLS.Managed = true
- }
- }
-}
-
-// ObtainCerts obtains certificates for all these configs as long as a
-// certificate does not already exist on disk. It does not modify the
-// configs at all; it only obtains and stores certificates and keys to
-// the disk. If allowPrompts is true, the user may be shown a prompt.
-// If proxyACME is true, the ACME challenges will be proxied to our alt port.
-func ObtainCerts(configs []server.Config, allowPrompts, proxyACME bool) error {
- // We group configs by email so we don't make the same clients over and
- // over. This has the potential to prompt the user for an email, but we
- // prevent that by assuming that if we already have a listener that can
- // proxy ACME challenge requests, then the server is already running and
- // the operator is no longer present.
- groupedConfigs := groupConfigsByEmail(configs, allowPrompts)
-
- for email, group := range groupedConfigs {
- // Wait as long as we can before creating the client, because it
- // may not be needed, for example, if we already have what we
- // need on disk. Creating a client involves the network and
- // potentially prompting the user, etc., so only do if necessary.
- var client *ACMEClient
-
- for _, cfg := range group {
- if existingCertAndKey(cfg.Host) {
- continue
- }
-
- // Now we definitely do need a client
- if client == nil {
- var err error
- client, err = NewACMEClient(email, allowPrompts)
- if err != nil {
- return errors.New("error creating client: " + err.Error())
- }
- }
-
- // c.Configure assumes that allowPrompts == !proxyACME,
- // but that's not always true. For example, a restart where
- // the user isn't present and we're not listening on port 80.
- // TODO: This could probably be refactored better.
- if proxyACME {
- client.SetHTTPAddress(net.JoinHostPort(cfg.BindHost, AlternatePort))
- client.SetTLSAddress(net.JoinHostPort(cfg.BindHost, AlternatePort))
- client.ExcludeChallenges([]acme.Challenge{acme.TLSSNI01, acme.DNS01})
- } else {
- client.SetHTTPAddress(net.JoinHostPort(cfg.BindHost, ""))
- client.SetTLSAddress(net.JoinHostPort(cfg.BindHost, ""))
- client.ExcludeChallenges([]acme.Challenge{acme.DNS01})
- }
-
- err := client.Obtain([]string{cfg.Host})
- if err != nil {
- return err
- }
- }
- }
-
- return nil
-}
-
-// groupConfigsByEmail groups configs by the email address to be used by an
-// ACME client. It only groups configs that have TLS enabled and that are
-// marked as Managed. If userPresent is true, the operator MAY be prompted
-// for an email address.
-func groupConfigsByEmail(configs []server.Config, userPresent bool) map[string][]server.Config {
- initMap := make(map[string][]server.Config)
- for _, cfg := range configs {
- if !cfg.TLS.Managed {
- continue
- }
- leEmail := getEmail(cfg, userPresent)
- initMap[leEmail] = append(initMap[leEmail], cfg)
- }
- return initMap
-}
-
-// EnableTLS configures each config to use TLS according to default settings.
-// It will only change configs that are marked as managed, and assumes that
-// certificates and keys are already on disk. If loadCertificates is true,
-// the certificates will be loaded from disk into the cache for this process
-// to use. If false, TLS will still be enabled and configured with default
-// settings, but no certificates will be parsed loaded into the cache, and
-// the returned error value will always be nil.
-func EnableTLS(configs []server.Config, loadCertificates bool) error {
- for i := 0; i < len(configs); i++ {
- if !configs[i].TLS.Managed {
- continue
- }
- configs[i].TLS.Enabled = true
- if loadCertificates {
- _, err := cacheManagedCertificate(configs[i].Host, false)
- if err != nil {
- return err
- }
- }
- setDefaultTLSParams(&configs[i])
- }
- return nil
-}
-
-// hostHasOtherPort returns true if there is another config in the list with the same
-// hostname that has port otherPort, or false otherwise. All the configs are checked
-// against the hostname of allConfigs[thisConfigIdx].
-func hostHasOtherPort(allConfigs []server.Config, thisConfigIdx int, otherPort string) bool {
- for i, otherCfg := range allConfigs {
- if i == thisConfigIdx {
- continue // has to be a config OTHER than the one we're comparing against
- }
- if otherCfg.Host == allConfigs[thisConfigIdx].Host && otherCfg.Port == otherPort {
- return true
- }
- }
- return false
-}
-
-// ConfigQualifies returns true if cfg qualifies for
-// fully managed TLS (but not on-demand TLS, which is
-// not considered here). It does NOT check to see if a
-// cert and key already exist for the config. If the
-// config does qualify, you should set cfg.TLS.Managed
-// to true and check that instead, because the process of
-// setting up the config may make it look like it
-// doesn't qualify even though it originally did.
-func ConfigQualifies(cfg server.Config) bool {
- return (!cfg.TLS.Manual || cfg.TLS.OnDemand) && // user might provide own cert and key
-
- // user can force-disable automatic HTTPS for this host
- cfg.Port != "80" &&
- cfg.TLS.LetsEncryptEmail != "off" &&
-
- // we get can't certs for some kinds of hostnames, but
- // on-demand TLS allows empty hostnames at startup
- cfg.TLS.OnDemand
-}
-
-// existingCertAndKey returns true if the host has a certificate
-// and private key in storage already, false otherwise.
-func existingCertAndKey(host string) bool {
- _, err := os.Stat(storage.SiteCertFile(host))
- if err != nil {
- return false
- }
- _, err = os.Stat(storage.SiteKeyFile(host))
- if err != nil {
- return false
- }
- return true
-}
-
-// saveCertResource saves the certificate resource to disk. This
-// includes the certificate file itself, the private key, and the
-// metadata file.
-func saveCertResource(cert acme.CertificateResource) error {
- err := os.MkdirAll(storage.Site(cert.Domain), 0700)
- if err != nil {
- return err
- }
-
- // Save cert
- err = ioutil.WriteFile(storage.SiteCertFile(cert.Domain), cert.Certificate, 0600)
- if err != nil {
- return err
- }
-
- // Save private key
- err = ioutil.WriteFile(storage.SiteKeyFile(cert.Domain), cert.PrivateKey, 0600)
- if err != nil {
- return err
- }
-
- // Save cert metadata
- jsonBytes, err := json.MarshalIndent(&cert, "", "\t")
- if err != nil {
- return err
- }
- err = ioutil.WriteFile(storage.SiteMetaFile(cert.Domain), jsonBytes, 0600)
- if err != nil {
- return err
- }
-
- return nil
-}
-
-// Revoke revokes the certificate for host via ACME protocol.
-func Revoke(host string) error {
- if !existingCertAndKey(host) {
- return errors.New("no certificate and key for " + host)
- }
-
- email := getEmail(server.Config{Host: host}, true)
- if email == "" {
- return errors.New("email is required to revoke")
- }
-
- client, err := NewACMEClient(email, true)
- if err != nil {
- return err
- }
-
- certFile := storage.SiteCertFile(host)
- certBytes, err := ioutil.ReadFile(certFile)
- if err != nil {
- return err
- }
-
- err = client.RevokeCertificate(certBytes)
- if err != nil {
- return err
- }
-
- err = os.Remove(certFile)
- if err != nil {
- return errors.New("certificate revoked, but unable to delete certificate file: " + err.Error())
- }
-
- return nil
-}
-
-var (
- // DefaultEmail represents the Let's Encrypt account email to use if none provided
- DefaultEmail string
-
- // Agreed indicates whether user has agreed to the Let's Encrypt SA
- Agreed bool
-
- // CAUrl represents the base URL to the CA's ACME endpoint
- CAUrl string
-)
-
-// AlternatePort is the port on which the acme client will open a
-// listener and solve the CA's challenges. If this alternate port
-// is used instead of the default port (80 or 443), then the
-// default port for the challenge must be forwarded to this one.
-const AlternatePort = "5033"
-
-// KeyType is the type to use for new keys.
-// This shouldn't need to change except for in tests;
-// the size can be drastically reduced for speed.
-var KeyType = acme.EC384
-
-// stopChan is used to signal the maintenance goroutine
-// to terminate.
-var stopChan chan struct{}
diff --git a/core/https/https_test.go b/core/https/https_test.go
deleted file mode 100644
index f19b3cde0..000000000
--- a/core/https/https_test.go
+++ /dev/null
@@ -1,323 +0,0 @@
-package https
-
-/*
-func TestHostQualifies(t *testing.T) {
- for i, test := range []struct {
- host string
- expect bool
- }{
- {"localhost", false},
- {"127.0.0.1", false},
- {"127.0.1.5", false},
- {"::1", false},
- {"[::1]", false},
- {"[::]", false},
- {"::", false},
- {"", false},
- {" ", false},
- {"0.0.0.0", false},
- {"192.168.1.3", false},
- {"10.0.2.1", false},
- {"169.112.53.4", false},
- {"foobar.com", true},
- {"sub.foobar.com", true},
- } {
- if HostQualifies(test.host) && !test.expect {
- t.Errorf("Test %d: Expected '%s' to NOT qualify, but it did", i, test.host)
- }
- if !HostQualifies(test.host) && test.expect {
- t.Errorf("Test %d: Expected '%s' to qualify, but it did NOT", i, test.host)
- }
- }
-}
-
-func TestConfigQualifies(t *testing.T) {
- for i, test := range []struct {
- cfg server.Config
- expect bool
- }{
- {server.Config{Host: ""}, false},
- {server.Config{Host: "localhost"}, false},
- {server.Config{Host: "123.44.3.21"}, false},
- {server.Config{Host: "example.com"}, true},
- {server.Config{Host: "example.com", TLS: server.TLSConfig{Manual: true}}, false},
- {server.Config{Host: "example.com", TLS: server.TLSConfig{LetsEncryptEmail: "off"}}, false},
- {server.Config{Host: "example.com", TLS: server.TLSConfig{LetsEncryptEmail: "foo@bar.com"}}, true},
- {server.Config{Host: "example.com", Scheme: "http"}, false},
- {server.Config{Host: "example.com", Port: "80"}, false},
- {server.Config{Host: "example.com", Port: "1234"}, true},
- {server.Config{Host: "example.com", Scheme: "https"}, true},
- {server.Config{Host: "example.com", Port: "80", Scheme: "https"}, false},
- } {
- if test.expect && !ConfigQualifies(test.cfg) {
- t.Errorf("Test %d: Expected config to qualify, but it did NOT: %#v", i, test.cfg)
- }
- if !test.expect && ConfigQualifies(test.cfg) {
- t.Errorf("Test %d: Expected config to NOT qualify, but it did: %#v", i, test.cfg)
- }
- }
-}
-
-func TestRedirPlaintextHost(t *testing.T) {
- cfg := redirPlaintextHost(server.Config{
- Host: "example.com",
- BindHost: "93.184.216.34",
- Port: "1234",
- })
-
- // Check host and port
- if actual, expected := cfg.Host, "example.com"; actual != expected {
- t.Errorf("Expected redir config to have host %s but got %s", expected, actual)
- }
- if actual, expected := cfg.BindHost, "93.184.216.34"; actual != expected {
- t.Errorf("Expected redir config to have bindhost %s but got %s", expected, actual)
- }
- if actual, expected := cfg.Port, "80"; actual != expected {
- t.Errorf("Expected redir config to have port '%s' but got '%s'", expected, actual)
- }
-
- // Make sure redirect handler is set up properly
- if cfg.Middleware == nil || len(cfg.Middleware) != 1 {
- t.Fatalf("Redir config middleware not set up properly; got: %#v", cfg.Middleware)
- }
-
- handler, ok := cfg.Middleware[0](nil).(redirect.Redirect)
- if !ok {
- t.Fatalf("Expected a redirect.Redirect middleware, but got: %#v", handler)
- }
- if len(handler.Rules) != 1 {
- t.Fatalf("Expected one redirect rule, got: %#v", handler.Rules)
- }
-
- // Check redirect rule for correctness
- if actual, expected := handler.Rules[0].FromScheme, "http"; actual != expected {
- t.Errorf("Expected redirect rule to be from scheme '%s' but is actually from '%s'", expected, actual)
- }
- if actual, expected := handler.Rules[0].FromPath, "/"; actual != expected {
- t.Errorf("Expected redirect rule to be for path '%s' but is actually for '%s'", expected, actual)
- }
- if actual, expected := handler.Rules[0].To, "https://{host}:1234{uri}"; actual != expected {
- t.Errorf("Expected redirect rule to be to URL '%s' but is actually to '%s'", expected, actual)
- }
- if actual, expected := handler.Rules[0].Code, http.StatusMovedPermanently; actual != expected {
- t.Errorf("Expected redirect rule to have code %d but was %d", expected, actual)
- }
-
- // browsers can infer a default port from scheme, so make sure the port
- // doesn't get added in explicitly for default ports like 443 for https.
- cfg = redirPlaintextHost(server.Config{Host: "example.com", Port: "443"})
- handler, ok = cfg.Middleware[0](nil).(redirect.Redirect)
- if actual, expected := handler.Rules[0].To, "https://{host}{uri}"; actual != expected {
- t.Errorf("(Default Port) Expected redirect rule to be to URL '%s' but is actually to '%s'", expected, actual)
- }
-}
-
-func TestSaveCertResource(t *testing.T) {
- storage = Storage("./le_test_save")
- defer func() {
- err := os.RemoveAll(string(storage))
- if err != nil {
- t.Fatalf("Could not remove temporary storage directory (%s): %v", storage, err)
- }
- }()
-
- domain := "example.com"
- certContents := "certificate"
- keyContents := "private key"
- metaContents := `{
- "domain": "example.com",
- "certUrl": "https://example.com/cert",
- "certStableUrl": "https://example.com/cert/stable"
-}`
-
- cert := acme.CertificateResource{
- Domain: domain,
- CertURL: "https://example.com/cert",
- CertStableURL: "https://example.com/cert/stable",
- PrivateKey: []byte(keyContents),
- Certificate: []byte(certContents),
- }
-
- err := saveCertResource(cert)
- if err != nil {
- t.Fatalf("Expected no error, got: %v", err)
- }
-
- certFile, err := ioutil.ReadFile(storage.SiteCertFile(domain))
- if err != nil {
- t.Errorf("Expected no error reading certificate file, got: %v", err)
- }
- if string(certFile) != certContents {
- t.Errorf("Expected certificate file to contain '%s', got '%s'", certContents, string(certFile))
- }
-
- keyFile, err := ioutil.ReadFile(storage.SiteKeyFile(domain))
- if err != nil {
- t.Errorf("Expected no error reading private key file, got: %v", err)
- }
- if string(keyFile) != keyContents {
- t.Errorf("Expected private key file to contain '%s', got '%s'", keyContents, string(keyFile))
- }
-
- metaFile, err := ioutil.ReadFile(storage.SiteMetaFile(domain))
- if err != nil {
- t.Errorf("Expected no error reading meta file, got: %v", err)
- }
- if string(metaFile) != metaContents {
- t.Errorf("Expected meta file to contain '%s', got '%s'", metaContents, string(metaFile))
- }
-}
-
-func TestExistingCertAndKey(t *testing.T) {
- storage = Storage("./le_test_existing")
- defer func() {
- err := os.RemoveAll(string(storage))
- if err != nil {
- t.Fatalf("Could not remove temporary storage directory (%s): %v", storage, err)
- }
- }()
-
- domain := "example.com"
-
- if existingCertAndKey(domain) {
- t.Errorf("Did NOT expect %v to have existing cert or key, but it did", domain)
- }
-
- err := saveCertResource(acme.CertificateResource{
- Domain: domain,
- PrivateKey: []byte("key"),
- Certificate: []byte("cert"),
- })
- if err != nil {
- t.Fatalf("Expected no error, got: %v", err)
- }
-
- if !existingCertAndKey(domain) {
- t.Errorf("Expected %v to have existing cert and key, but it did NOT", domain)
- }
-}
-
-func TestHostHasOtherPort(t *testing.T) {
- configs := []server.Config{
- {Host: "example.com", Port: "80"},
- {Host: "sub1.example.com", Port: "80"},
- {Host: "sub1.example.com", Port: "443"},
- }
-
- if hostHasOtherPort(configs, 0, "80") {
- t.Errorf(`Expected hostHasOtherPort(configs, 0, "80") to be false, but got true`)
- }
- if hostHasOtherPort(configs, 0, "443") {
- t.Errorf(`Expected hostHasOtherPort(configs, 0, "443") to be false, but got true`)
- }
- if !hostHasOtherPort(configs, 1, "443") {
- t.Errorf(`Expected hostHasOtherPort(configs, 1, "443") to be true, but got false`)
- }
-}
-
-func TestMakePlaintextRedirects(t *testing.T) {
- configs := []server.Config{
- // Happy path = standard redirect from 80 to 443
- {Host: "example.com", TLS: server.TLSConfig{Managed: true}},
-
- // Host on port 80 already defined; don't change it (no redirect)
- {Host: "sub1.example.com", Port: "80", Scheme: "http"},
- {Host: "sub1.example.com", TLS: server.TLSConfig{Managed: true}},
-
- // Redirect from port 80 to port 5000 in this case
- {Host: "sub2.example.com", Port: "5000", TLS: server.TLSConfig{Managed: true}},
-
- // Can redirect from 80 to either 443 or 5001, but choose 443
- {Host: "sub3.example.com", Port: "443", TLS: server.TLSConfig{Managed: true}},
- {Host: "sub3.example.com", Port: "5001", Scheme: "https", TLS: server.TLSConfig{Managed: true}},
- }
-
- result := MakePlaintextRedirects(configs)
- expectedRedirCount := 3
-
- if len(result) != len(configs)+expectedRedirCount {
- t.Errorf("Expected %d redirect(s) to be added, but got %d",
- expectedRedirCount, len(result)-len(configs))
- }
-}
-
-func TestEnableTLS(t *testing.T) {
- configs := []server.Config{
- {Host: "example.com", TLS: server.TLSConfig{Managed: true}},
- {}, // not managed - no changes!
- }
-
- EnableTLS(configs, false)
-
- if !configs[0].TLS.Enabled {
- t.Errorf("Expected config 0 to have TLS.Enabled == true, but it was false")
- }
- if configs[1].TLS.Enabled {
- t.Errorf("Expected config 1 to have TLS.Enabled == false, but it was true")
- }
-}
-
-func TestGroupConfigsByEmail(t *testing.T) {
- if groupConfigsByEmail([]server.Config{}, false) == nil {
- t.Errorf("With empty input, returned map was nil, but expected non-nil map")
- }
-
- configs := []server.Config{
- {Host: "example.com", TLS: server.TLSConfig{LetsEncryptEmail: "", Managed: true}},
- {Host: "sub1.example.com", TLS: server.TLSConfig{LetsEncryptEmail: "foo@bar", Managed: true}},
- {Host: "sub2.example.com", TLS: server.TLSConfig{LetsEncryptEmail: "", Managed: true}},
- {Host: "sub3.example.com", TLS: server.TLSConfig{LetsEncryptEmail: "foo@bar", Managed: true}},
- {Host: "sub4.example.com", TLS: server.TLSConfig{LetsEncryptEmail: "", Managed: true}},
- {Host: "sub5.example.com", TLS: server.TLSConfig{LetsEncryptEmail: ""}}, // not managed
- }
- DefaultEmail = "test@example.com"
-
- groups := groupConfigsByEmail(configs, true)
-
- if groups == nil {
- t.Fatalf("Returned map was nil, but expected values")
- }
-
- if len(groups) != 2 {
- t.Errorf("Expected 2 groups, got %d: %#v", len(groups), groups)
- }
- if len(groups["foo@bar"]) != 2 {
- t.Errorf("Expected 2 configs for foo@bar, got %d: %#v", len(groups["foobar"]), groups["foobar"])
- }
- if len(groups[DefaultEmail]) != 3 {
- t.Errorf("Expected 3 configs for %s, got %d: %#v", DefaultEmail, len(groups["foobar"]), groups["foobar"])
- }
-}
-
-func TestMarkQualified(t *testing.T) {
- // TODO: TestConfigQualifies and this test share the same config list...
- configs := []server.Config{
- {Host: ""},
- {Host: "localhost"},
- {Host: "123.44.3.21"},
- {Host: "example.com"},
- {Host: "example.com", TLS: server.TLSConfig{Manual: true}},
- {Host: "example.com", TLS: server.TLSConfig{LetsEncryptEmail: "off"}},
- {Host: "example.com", TLS: server.TLSConfig{LetsEncryptEmail: "foo@bar.com"}},
- {Host: "example.com", Scheme: "http"},
- {Host: "example.com", Port: "80"},
- {Host: "example.com", Port: "1234"},
- {Host: "example.com", Scheme: "https"},
- {Host: "example.com", Port: "80", Scheme: "https"},
- }
- expectedManagedCount := 4
-
- MarkQualified(configs)
-
- count := 0
- for _, cfg := range configs {
- if cfg.TLS.Managed {
- count++
- }
- }
-
- if count != expectedManagedCount {
- t.Errorf("Expected %d managed configs, but got %d", expectedManagedCount, count)
- }
-}
-*/
diff --git a/core/https/maintain.go b/core/https/maintain.go
deleted file mode 100644
index 46fd3d1f5..000000000
--- a/core/https/maintain.go
+++ /dev/null
@@ -1,211 +0,0 @@
-package https
-
-import (
- "log"
- "time"
-
- "github.com/miekg/coredns/server"
-
- "golang.org/x/crypto/ocsp"
-)
-
-const (
- // RenewInterval is how often to check certificates for renewal.
- RenewInterval = 12 * time.Hour
-
- // OCSPInterval is how often to check if OCSP stapling needs updating.
- OCSPInterval = 1 * time.Hour
-)
-
-// maintainAssets is a permanently-blocking function
-// that loops indefinitely and, on a regular schedule, checks
-// certificates for expiration and initiates a renewal of certs
-// that are expiring soon. It also updates OCSP stapling and
-// performs other maintenance of assets.
-//
-// You must pass in the channel which you'll close when
-// maintenance should stop, to allow this goroutine to clean up
-// after itself and unblock.
-func maintainAssets(stopChan chan struct{}) {
- renewalTicker := time.NewTicker(RenewInterval)
- ocspTicker := time.NewTicker(OCSPInterval)
-
- for {
- select {
- case <-renewalTicker.C:
- log.Println("[INFO] Scanning for expiring certificates")
- renewManagedCertificates(false)
- log.Println("[INFO] Done checking certificates")
- case <-ocspTicker.C:
- log.Println("[INFO] Scanning for stale OCSP staples")
- updateOCSPStaples()
- log.Println("[INFO] Done checking OCSP staples")
- case <-stopChan:
- renewalTicker.Stop()
- ocspTicker.Stop()
- log.Println("[INFO] Stopped background maintenance routine")
- return
- }
- }
-}
-
-func renewManagedCertificates(allowPrompts bool) (err error) {
- var renewed, deleted []Certificate
- var client *ACMEClient
- visitedNames := make(map[string]struct{})
-
- certCacheMu.RLock()
- for name, cert := range certCache {
- if !cert.Managed {
- continue
- }
-
- // the list of names on this cert should never be empty...
- if cert.Names == nil || len(cert.Names) == 0 {
- log.Printf("[WARNING] Certificate keyed by '%s' has no names: %v", name, cert.Names)
- deleted = append(deleted, cert)
- continue
- }
-
- // skip names whose certificate we've already renewed
- if _, ok := visitedNames[name]; ok {
- continue
- }
- for _, name := range cert.Names {
- visitedNames[name] = struct{}{}
- }
-
- timeLeft := cert.NotAfter.Sub(time.Now().UTC())
- if timeLeft < renewDurationBefore {
- log.Printf("[INFO] Certificate for %v expires in %v; attempting renewal", cert.Names, timeLeft)
-
- if client == nil {
- client, err = NewACMEClientGetEmail(server.Config{}, allowPrompts)
- if err != nil {
- return err
- }
- client.Configure("") // TODO: Bind address of relevant listener, yuck
- }
-
- err := client.Renew(cert.Names[0]) // managed certs better have only one name
- if err != nil {
- if client.AllowPrompts && timeLeft < 0 {
- // Certificate renewal failed, the operator is present, and the certificate
- // is already expired; we should stop immediately and return the error. Note
- // that we used to do this any time a renewal failed at startup. However,
- // after discussion in https://github.com/miekg/coredns/issues/642 we decided to
- // only stop startup if the certificate is expired. We still log the error
- // otherwise.
- certCacheMu.RUnlock()
- return err
- }
- log.Printf("[ERROR] %v", err)
- if cert.OnDemand {
- deleted = append(deleted, cert)
- }
- } else {
- renewed = append(renewed, cert)
- }
- }
- }
- certCacheMu.RUnlock()
-
- // Apply changes to the cache
- for _, cert := range renewed {
- _, err := cacheManagedCertificate(cert.Names[0], cert.OnDemand)
- if err != nil {
- if client.AllowPrompts {
- return err // operator is present, so report error immediately
- }
- log.Printf("[ERROR] %v", err)
- }
- }
- for _, cert := range deleted {
- certCacheMu.Lock()
- for _, name := range cert.Names {
- delete(certCache, name)
- }
- certCacheMu.Unlock()
- }
-
- return nil
-}
-
-func updateOCSPStaples() {
- // Create a temporary place to store updates
- // until we release the potentially long-lived
- // read lock and use a short-lived write lock.
- type ocspUpdate struct {
- rawBytes []byte
- parsed *ocsp.Response
- }
- updated := make(map[string]ocspUpdate)
-
- // A single SAN certificate maps to multiple names, so we use this
- // set to make sure we don't waste cycles checking OCSP for the same
- // certificate multiple times.
- visited := make(map[string]struct{})
-
- certCacheMu.RLock()
- for name, cert := range certCache {
- // skip this certificate if we've already visited it,
- // and if not, mark all the names as visited
- if _, ok := visited[name]; ok {
- continue
- }
- for _, n := range cert.Names {
- visited[n] = struct{}{}
- }
-
- // no point in updating OCSP for expired certificates
- if time.Now().After(cert.NotAfter) {
- continue
- }
-
- var lastNextUpdate time.Time
- if cert.OCSP != nil {
- // start checking OCSP staple about halfway through validity period for good measure
- lastNextUpdate = cert.OCSP.NextUpdate
- refreshTime := cert.OCSP.ThisUpdate.Add(lastNextUpdate.Sub(cert.OCSP.ThisUpdate) / 2)
-
- // since OCSP is already stapled, we need only check if we're in that "refresh window"
- if time.Now().Before(refreshTime) {
- continue
- }
- }
-
- err := stapleOCSP(&cert, nil)
- if err != nil {
- if cert.OCSP != nil {
- // if it was no staple before, that's fine, otherwise we should log the error
- log.Printf("[ERROR] Checking OCSP for %s: %v", name, err)
- }
- continue
- }
-
- // By this point, we've obtained the latest OCSP response.
- // If there was no staple before, or if the response is updated, make
- // sure we apply the update to all names on the certificate.
- if lastNextUpdate.IsZero() || lastNextUpdate != cert.OCSP.NextUpdate {
- log.Printf("[INFO] Advancing OCSP staple for %v from %s to %s",
- cert.Names, lastNextUpdate, cert.OCSP.NextUpdate)
- for _, n := range cert.Names {
- updated[n] = ocspUpdate{rawBytes: cert.Certificate.OCSPStaple, parsed: cert.OCSP}
- }
- }
- }
- certCacheMu.RUnlock()
-
- // This write lock should be brief since we have all the info we need now.
- certCacheMu.Lock()
- for name, update := range updated {
- cert := certCache[name]
- cert.OCSP = update.parsed
- cert.Certificate.OCSPStaple = update.rawBytes
- certCache[name] = cert
- }
- certCacheMu.Unlock()
-}
-
-// renewDurationBefore is how long before expiration to renew certificates.
-const renewDurationBefore = (24 * time.Hour) * 30
diff --git a/core/https/setup.go b/core/https/setup.go
deleted file mode 100644
index ec90e0284..000000000
--- a/core/https/setup.go
+++ /dev/null
@@ -1,321 +0,0 @@
-package https
-
-import (
- "bytes"
- "crypto/tls"
- "encoding/pem"
- "io/ioutil"
- "log"
- "os"
- "path/filepath"
- "strconv"
- "strings"
-
- "github.com/miekg/coredns/core/setup"
- "github.com/miekg/coredns/middleware"
- "github.com/miekg/coredns/server"
-)
-
-// Setup sets up the TLS configuration and installs certificates that
-// are specified by the user in the config file. All the automatic HTTPS
-// stuff comes later outside of this function.
-func Setup(c *setup.Controller) (middleware.Middleware, error) {
- if c.Port == "80" {
- c.TLS.Enabled = false
- log.Printf("[WARNING] TLS disabled for %s.", c.Address())
- return nil, nil
- }
- c.TLS.Enabled = true
-
- // TODO(miek): disabled for now
- return nil, nil
-
- for c.Next() {
- var certificateFile, keyFile, loadDir, maxCerts string
-
- args := c.RemainingArgs()
- switch len(args) {
- case 1:
- c.TLS.LetsEncryptEmail = args[0]
-
- // user can force-disable managed TLS this way
- if c.TLS.LetsEncryptEmail == "off" {
- c.TLS.Enabled = false
- return nil, nil
- }
- case 2:
- certificateFile = args[0]
- keyFile = args[1]
- c.TLS.Manual = true
- }
-
- // Optional block with extra parameters
- var hadBlock bool
- for c.NextBlock() {
- hadBlock = true
- switch c.Val() {
- case "protocols":
- args := c.RemainingArgs()
- if len(args) != 2 {
- return nil, c.ArgErr()
- }
- value, ok := supportedProtocols[strings.ToLower(args[0])]
- if !ok {
- return nil, c.Errf("Wrong protocol name or protocol not supported '%s'", c.Val())
- }
- c.TLS.ProtocolMinVersion = value
- value, ok = supportedProtocols[strings.ToLower(args[1])]
- if !ok {
- return nil, c.Errf("Wrong protocol name or protocol not supported '%s'", c.Val())
- }
- c.TLS.ProtocolMaxVersion = value
- case "ciphers":
- for c.NextArg() {
- value, ok := supportedCiphersMap[strings.ToUpper(c.Val())]
- if !ok {
- return nil, c.Errf("Wrong cipher name or cipher not supported '%s'", c.Val())
- }
- c.TLS.Ciphers = append(c.TLS.Ciphers, value)
- }
- case "clients":
- c.TLS.ClientCerts = c.RemainingArgs()
- if len(c.TLS.ClientCerts) == 0 {
- return nil, c.ArgErr()
- }
- case "load":
- c.Args(&loadDir)
- c.TLS.Manual = true
- case "max_certs":
- c.Args(&maxCerts)
- c.TLS.OnDemand = true
- default:
- return nil, c.Errf("Unknown keyword '%s'", c.Val())
- }
- }
-
- // tls requires at least one argument if a block is not opened
- if len(args) == 0 && !hadBlock {
- return nil, c.ArgErr()
- }
-
- // set certificate limit if on-demand TLS is enabled
- if maxCerts != "" {
- maxCertsNum, err := strconv.Atoi(maxCerts)
- if err != nil || maxCertsNum < 1 {
- return nil, c.Err("max_certs must be a positive integer")
- }
- if onDemandMaxIssue == 0 || int32(maxCertsNum) < onDemandMaxIssue { // keep the minimum; TODO: We have to do this because it is global; should be per-server or per-vhost...
- onDemandMaxIssue = int32(maxCertsNum)
- }
- }
-
- // don't try to load certificates unless we're supposed to
- if !c.TLS.Enabled || !c.TLS.Manual {
- continue
- }
-
- // load a single certificate and key, if specified
- if certificateFile != "" && keyFile != "" {
- err := cacheUnmanagedCertificatePEMFile(certificateFile, keyFile)
- if err != nil {
- return nil, c.Errf("Unable to load certificate and key files for %s: %v", c.Host, err)
- }
- log.Printf("[INFO] Successfully loaded TLS assets from %s and %s", certificateFile, keyFile)
- }
-
- // load a directory of certificates, if specified
- if loadDir != "" {
- err := loadCertsInDir(c, loadDir)
- if err != nil {
- return nil, err
- }
- }
- }
-
- setDefaultTLSParams(c.Config)
-
- return nil, nil
-}
-
-// loadCertsInDir loads all the certificates/keys in dir, as long as
-// the file ends with .pem. This method of loading certificates is
-// modeled after haproxy, which expects the certificate and key to
-// be bundled into the same file:
-// https://cbonte.github.io/haproxy-dconv/configuration-1.5.html#5.1-crt
-//
-// This function may write to the log as it walks the directory tree.
-func loadCertsInDir(c *setup.Controller, dir string) error {
- return filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
- if err != nil {
- log.Printf("[WARNING] Unable to traverse into %s; skipping", path)
- return nil
- }
- if info.IsDir() {
- return nil
- }
- if strings.HasSuffix(strings.ToLower(info.Name()), ".pem") {
- certBuilder, keyBuilder := new(bytes.Buffer), new(bytes.Buffer)
- var foundKey bool // use only the first key in the file
-
- bundle, err := ioutil.ReadFile(path)
- if err != nil {
- return err
- }
-
- for {
- // Decode next block so we can see what type it is
- var derBlock *pem.Block
- derBlock, bundle = pem.Decode(bundle)
- if derBlock == nil {
- break
- }
-
- if derBlock.Type == "CERTIFICATE" {
- // Re-encode certificate as PEM, appending to certificate chain
- pem.Encode(certBuilder, derBlock)
- } else if derBlock.Type == "EC PARAMETERS" {
- // EC keys generated from openssl can be composed of two blocks:
- // parameters and key (parameter block should come first)
- if !foundKey {
- // Encode parameters
- pem.Encode(keyBuilder, derBlock)
-
- // Key must immediately follow
- derBlock, bundle = pem.Decode(bundle)
- if derBlock == nil || derBlock.Type != "EC PRIVATE KEY" {
- return c.Errf("%s: expected elliptic private key to immediately follow EC parameters", path)
- }
- pem.Encode(keyBuilder, derBlock)
- foundKey = true
- }
- } else if derBlock.Type == "PRIVATE KEY" || strings.HasSuffix(derBlock.Type, " PRIVATE KEY") {
- // RSA key
- if !foundKey {
- pem.Encode(keyBuilder, derBlock)
- foundKey = true
- }
- } else {
- return c.Errf("%s: unrecognized PEM block type: %s", path, derBlock.Type)
- }
- }
-
- certPEMBytes, keyPEMBytes := certBuilder.Bytes(), keyBuilder.Bytes()
- if len(certPEMBytes) == 0 {
- return c.Errf("%s: failed to parse PEM data", path)
- }
- if len(keyPEMBytes) == 0 {
- return c.Errf("%s: no private key block found", path)
- }
-
- err = cacheUnmanagedCertificatePEMBytes(certPEMBytes, keyPEMBytes)
- if err != nil {
- return c.Errf("%s: failed to load cert and key for %s: %v", path, c.Host, err)
- }
- log.Printf("[INFO] Successfully loaded TLS assets from %s", path)
- }
- return nil
- })
-}
-
-// setDefaultTLSParams sets the default TLS cipher suites, protocol versions,
-// and server preferences of a server.Config if they were not previously set
-// (it does not overwrite; only fills in missing values). It will also set the
-// port to 443 if not already set, TLS is enabled, TLS is manual, and the host
-// does not equal localhost.
-func setDefaultTLSParams(c *server.Config) {
- // If no ciphers provided, use default list
- if len(c.TLS.Ciphers) == 0 {
- c.TLS.Ciphers = defaultCiphers
- }
-
- // Not a cipher suite, but still important for mitigating protocol downgrade attacks
- // (prepend since having it at end breaks http2 due to non-h2-approved suites before it)
- c.TLS.Ciphers = append([]uint16{tls.TLS_FALLBACK_SCSV}, c.TLS.Ciphers...)
-
- // Set default protocol min and max versions - must balance compatibility and security
- if c.TLS.ProtocolMinVersion == 0 {
- c.TLS.ProtocolMinVersion = tls.VersionTLS10
- }
- if c.TLS.ProtocolMaxVersion == 0 {
- c.TLS.ProtocolMaxVersion = tls.VersionTLS12
- }
-
- // Prefer server cipher suites
- c.TLS.PreferServerCipherSuites = true
-
- // Default TLS port is 443; only use if port is not manually specified,
- // TLS is enabled, and the host is not localhost
- if c.Port == "" && c.TLS.Enabled && (!c.TLS.Manual || c.TLS.OnDemand) && c.Host != "localhost" {
- c.Port = "443"
- }
-}
-
-// Map of supported protocols.
-// SSLv3 will be not supported in future release.
-// HTTP/2 only supports TLS 1.2 and higher.
-var supportedProtocols = map[string]uint16{
- "ssl3.0": tls.VersionSSL30,
- "tls1.0": tls.VersionTLS10,
- "tls1.1": tls.VersionTLS11,
- "tls1.2": tls.VersionTLS12,
-}
-
-// Map of supported ciphers, used only for parsing config.
-//
-// Note that, at time of writing, HTTP/2 blacklists 276 cipher suites,
-// including all but two of the suites below (the two GCM suites).
-// See https://http2.github.io/http2-spec/#BadCipherSuites
-//
-// TLS_FALLBACK_SCSV is not in this list because we manually ensure
-// it is always added (even though it is not technically a cipher suite).
-//
-// This map, like any map, is NOT ORDERED. Do not range over this map.
-var supportedCiphersMap = map[string]uint16{
- "ECDHE-RSA-AES256-GCM-SHA384": tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
- "ECDHE-ECDSA-AES256-GCM-SHA384": tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
- "ECDHE-RSA-AES128-GCM-SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
- "ECDHE-ECDSA-AES128-GCM-SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
- "ECDHE-RSA-AES128-CBC-SHA": tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
- "ECDHE-RSA-AES256-CBC-SHA": tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
- "ECDHE-ECDSA-AES256-CBC-SHA": tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
- "ECDHE-ECDSA-AES128-CBC-SHA": tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
- "RSA-AES128-CBC-SHA": tls.TLS_RSA_WITH_AES_128_CBC_SHA,
- "RSA-AES256-CBC-SHA": tls.TLS_RSA_WITH_AES_256_CBC_SHA,
- "ECDHE-RSA-3DES-EDE-CBC-SHA": tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA,
- "RSA-3DES-EDE-CBC-SHA": tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA,
-}
-
-// List of supported cipher suites in descending order of preference.
-// Ordering is very important! Getting the wrong order will break
-// mainstream clients, especially with HTTP/2.
-//
-// Note that TLS_FALLBACK_SCSV is not in this list since it is always
-// added manually.
-var supportedCiphers = []uint16{
- tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
- tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
- tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
- tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
- tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
- tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
- tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
- tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
- tls.TLS_RSA_WITH_AES_256_CBC_SHA,
- tls.TLS_RSA_WITH_AES_128_CBC_SHA,
- tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA,
- tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA,
-}
-
-// List of all the ciphers we want to use by default
-var defaultCiphers = []uint16{
- tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
- tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
- tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
- tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
- tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
- tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
- tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
- tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
- tls.TLS_RSA_WITH_AES_256_CBC_SHA,
- tls.TLS_RSA_WITH_AES_128_CBC_SHA,
-}
diff --git a/core/https/setup_test.go b/core/https/setup_test.go
deleted file mode 100644
index 7640eb524..000000000
--- a/core/https/setup_test.go
+++ /dev/null
@@ -1,226 +0,0 @@
-package https
-
-// TODO(miek): all fail
-
-/*
-func TestMain(m *testing.M) {
- // Write test certificates to disk before tests, and clean up
- // when we're done.
- err := ioutil.WriteFile(certFile, testCert, 0644)
- if err != nil {
- log.Fatal(err)
- }
- err = ioutil.WriteFile(keyFile, testKey, 0644)
- if err != nil {
- os.Remove(certFile)
- log.Fatal(err)
- }
-
- result := m.Run()
-
- os.Remove(certFile)
- os.Remove(keyFile)
- os.Exit(result)
-}
-
-func TestSetupParseBasic(t *testing.T) {
- c := setup.NewTestController(`tls ` + certFile + ` ` + keyFile + ``)
-
- _, err := Setup(c)
- if err != nil {
- t.Errorf("Expected no errors, got: %v", err)
- }
-
- // Basic checks
- if !c.TLS.Manual {
- t.Error("Expected TLS Manual=true, but was false")
- }
- if !c.TLS.Enabled {
- t.Error("Expected TLS Enabled=true, but was false")
- }
-
- // Security defaults
- if c.TLS.ProtocolMinVersion != tls.VersionTLS10 {
- t.Errorf("Expected 'tls1.0 (0x0301)' as ProtocolMinVersion, got %#v", c.TLS.ProtocolMinVersion)
- }
- if c.TLS.ProtocolMaxVersion != tls.VersionTLS12 {
- t.Errorf("Expected 'tls1.2 (0x0303)' as ProtocolMaxVersion, got %v", c.TLS.ProtocolMaxVersion)
- }
-
- // Cipher checks
- expectedCiphers := []uint16{
- tls.TLS_FALLBACK_SCSV,
- tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
- tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
- tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
- tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
- tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
- tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
- tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
- tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
- tls.TLS_RSA_WITH_AES_256_CBC_SHA,
- tls.TLS_RSA_WITH_AES_128_CBC_SHA,
- }
-
- // Ensure count is correct (plus one for TLS_FALLBACK_SCSV)
- if len(c.TLS.Ciphers) != len(expectedCiphers) {
- t.Errorf("Expected %v Ciphers (including TLS_FALLBACK_SCSV), got %v",
- len(expectedCiphers), len(c.TLS.Ciphers))
- }
-
- // Ensure ordering is correct
- for i, actual := range c.TLS.Ciphers {
- if actual != expectedCiphers[i] {
- t.Errorf("Expected cipher in position %d to be %0x, got %0x", i, expectedCiphers[i], actual)
- }
- }
-
- if !c.TLS.PreferServerCipherSuites {
- t.Error("Expected PreferServerCipherSuites = true, but was false")
- }
-}
-
-func TestSetupParseIncompleteParams(t *testing.T) {
- // Using tls without args is an error because it's unnecessary.
- c := setup.NewTestController(`tls`)
- _, err := Setup(c)
- if err == nil {
- t.Error("Expected an error, but didn't get one")
- }
-}
-
-func TestSetupParseWithOptionalParams(t *testing.T) {
- params := `tls ` + certFile + ` ` + keyFile + ` {
- protocols ssl3.0 tls1.2
- ciphers RSA-AES256-CBC-SHA ECDHE-RSA-AES128-GCM-SHA256 ECDHE-ECDSA-AES256-GCM-SHA384
- }`
- c := setup.NewTestController(params)
-
- _, err := Setup(c)
- if err != nil {
- t.Errorf("Expected no errors, got: %v", err)
- }
-
- if c.TLS.ProtocolMinVersion != tls.VersionSSL30 {
- t.Errorf("Expected 'ssl3.0 (0x0300)' as ProtocolMinVersion, got %#v", c.TLS.ProtocolMinVersion)
- }
-
- if c.TLS.ProtocolMaxVersion != tls.VersionTLS12 {
- t.Errorf("Expected 'tls1.2 (0x0302)' as ProtocolMaxVersion, got %#v", c.TLS.ProtocolMaxVersion)
- }
-
- if len(c.TLS.Ciphers)-1 != 3 {
- t.Errorf("Expected 3 Ciphers (not including TLS_FALLBACK_SCSV), got %v", len(c.TLS.Ciphers)-1)
- }
-}
-
-func TestSetupDefaultWithOptionalParams(t *testing.T) {
- params := `tls {
- ciphers RSA-3DES-EDE-CBC-SHA
- }`
- c := setup.NewTestController(params)
-
- _, err := Setup(c)
- if err != nil {
- t.Errorf("Expected no errors, got: %v", err)
- }
- if len(c.TLS.Ciphers)-1 != 1 {
- t.Errorf("Expected 1 ciphers (not including TLS_FALLBACK_SCSV), got %v", len(c.TLS.Ciphers)-1)
- }
-}
-
-// TODO: If we allow this... but probably not a good idea.
-// func TestSetupDisableHTTPRedirect(t *testing.T) {
-// c := NewTestController(`tls {
-// allow_http
-// }`)
-// _, err := TLS(c)
-// if err != nil {
-// t.Errorf("Expected no error, but got %v", err)
-// }
-// if !c.TLS.DisableHTTPRedir {
-// t.Error("Expected HTTP redirect to be disabled, but it wasn't")
-// }
-// }
-
-func TestSetupParseWithWrongOptionalParams(t *testing.T) {
- // Test protocols wrong params
- params := `tls ` + certFile + ` ` + keyFile + ` {
- protocols ssl tls
- }`
- c := setup.NewTestController(params)
- _, err := Setup(c)
- if err == nil {
- t.Errorf("Expected errors, but no error returned")
- }
-
- // Test ciphers wrong params
- params = `tls ` + certFile + ` ` + keyFile + ` {
- ciphers not-valid-cipher
- }`
- c = setup.NewTestController(params)
- _, err = Setup(c)
- if err == nil {
- t.Errorf("Expected errors, but no error returned")
- }
-}
-
-func TestSetupParseWithClientAuth(t *testing.T) {
- params := `tls ` + certFile + ` ` + keyFile + ` {
- clients client_ca.crt client2_ca.crt
- }`
- c := setup.NewTestController(params)
- _, err := Setup(c)
- if err != nil {
- t.Errorf("Expected no errors, got: %v", err)
- }
-
- if count := len(c.TLS.ClientCerts); count != 2 {
- t.Fatalf("Expected two client certs, had %d", count)
- }
- if actual := c.TLS.ClientCerts[0]; actual != "client_ca.crt" {
- t.Errorf("Expected first client cert file to be '%s', but was '%s'", "client_ca.crt", actual)
- }
- if actual := c.TLS.ClientCerts[1]; actual != "client2_ca.crt" {
- t.Errorf("Expected second client cert file to be '%s', but was '%s'", "client2_ca.crt", actual)
- }
-
- // Test missing client cert file
- params = `tls ` + certFile + ` ` + keyFile + ` {
- clients
- }`
- c = setup.NewTestController(params)
- _, err = Setup(c)
- if err == nil {
- t.Errorf("Expected an error, but no error returned")
- }
-}
-
-const (
- certFile = "test_cert.pem"
- keyFile = "test_key.pem"
-)
-
-var testCert = []byte(`-----BEGIN CERTIFICATE-----
-MIIBkjCCATmgAwIBAgIJANfFCBcABL6LMAkGByqGSM49BAEwFDESMBAGA1UEAxMJ
-bG9jYWxob3N0MB4XDTE2MDIxMDIyMjAyNFoXDTE4MDIwOTIyMjAyNFowFDESMBAG
-A1UEAxMJbG9jYWxob3N0MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEs22MtnG7
-9K1mvIyjEO9GLx7BFD0tBbGnwQ0VPsuCxC6IeVuXbQDLSiVQvFZ6lUszTlczNxVk
-pEfqrM6xAupB7qN1MHMwHQYDVR0OBBYEFHxYDvAxUwL4XrjPev6qZ/BiLDs5MEQG
-A1UdIwQ9MDuAFHxYDvAxUwL4XrjPev6qZ/BiLDs5oRikFjAUMRIwEAYDVQQDEwls
-b2NhbGhvc3SCCQDXxQgXAAS+izAMBgNVHRMEBTADAQH/MAkGByqGSM49BAEDSAAw
-RQIgRvBqbyJM2JCJqhA1FmcoZjeMocmhxQHTt1c+1N2wFUgCIQDtvrivbBPA688N
-Qh3sMeAKNKPsx5NxYdoWuu9KWcKz9A==
------END CERTIFICATE-----
-`)
-
-var testKey = []byte(`-----BEGIN EC PARAMETERS-----
-BggqhkjOPQMBBw==
------END EC PARAMETERS-----
------BEGIN EC PRIVATE KEY-----
-MHcCAQEEIGLtRmwzYVcrH3J0BnzYbGPdWVF10i9p6mxkA4+b2fURoAoGCCqGSM49
-AwEHoUQDQgAEs22MtnG79K1mvIyjEO9GLx7BFD0tBbGnwQ0VPsuCxC6IeVuXbQDL
-SiVQvFZ6lUszTlczNxVkpEfqrM6xAupB7g==
------END EC PRIVATE KEY-----
-`)
-*/
diff --git a/core/https/storage.go b/core/https/storage.go
deleted file mode 100644
index 5d8e949da..000000000
--- a/core/https/storage.go
+++ /dev/null
@@ -1,94 +0,0 @@
-package https
-
-import (
- "path/filepath"
- "strings"
-
- "github.com/miekg/coredns/core/assets"
-)
-
-// storage is used to get file paths in a consistent,
-// cross-platform way for persisting Let's Encrypt assets
-// on the file system.
-var storage = Storage(filepath.Join(assets.Path(), "letsencrypt"))
-
-// Storage is a root directory and facilitates
-// forming file paths derived from it.
-type Storage string
-
-// Sites gets the directory that stores site certificate and keys.
-func (s Storage) Sites() string {
- return filepath.Join(string(s), "sites")
-}
-
-// Site returns the path to the folder containing assets for domain.
-func (s Storage) Site(domain string) string {
- return filepath.Join(s.Sites(), domain)
-}
-
-// SiteCertFile returns the path to the certificate file for domain.
-func (s Storage) SiteCertFile(domain string) string {
- return filepath.Join(s.Site(domain), domain+".crt")
-}
-
-// SiteKeyFile returns the path to domain's private key file.
-func (s Storage) SiteKeyFile(domain string) string {
- return filepath.Join(s.Site(domain), domain+".key")
-}
-
-// SiteMetaFile returns the path to the domain's asset metadata file.
-func (s Storage) SiteMetaFile(domain string) string {
- return filepath.Join(s.Site(domain), domain+".json")
-}
-
-// Users gets the directory that stores account folders.
-func (s Storage) Users() string {
- return filepath.Join(string(s), "users")
-}
-
-// User gets the account folder for the user with email.
-func (s Storage) User(email string) string {
- if email == "" {
- email = emptyEmail
- }
- return filepath.Join(s.Users(), email)
-}
-
-// UserRegFile gets the path to the registration file for
-// the user with the given email address.
-func (s Storage) UserRegFile(email string) string {
- if email == "" {
- email = emptyEmail
- }
- fileName := emailUsername(email)
- if fileName == "" {
- fileName = "registration"
- }
- return filepath.Join(s.User(email), fileName+".json")
-}
-
-// UserKeyFile gets the path to the private key file for
-// the user with the given email address.
-func (s Storage) UserKeyFile(email string) string {
- if email == "" {
- email = emptyEmail
- }
- fileName := emailUsername(email)
- if fileName == "" {
- fileName = "private"
- }
- return filepath.Join(s.User(email), fileName+".key")
-}
-
-// emailUsername returns the username portion of an
-// email address (part before '@') or the original
-// input if it can't find the "@" symbol.
-func emailUsername(email string) string {
- at := strings.Index(email, "@")
- if at == -1 {
- return email
- } else if at == 0 {
- return email[1:]
- }
- return email[:at]
-}
diff --git a/core/https/storage_test.go b/core/https/storage_test.go
deleted file mode 100644
index 85c2220eb..000000000
--- a/core/https/storage_test.go
+++ /dev/null
@@ -1,88 +0,0 @@
-package https
-
-import (
- "path/filepath"
- "testing"
-)
-
-func TestStorage(t *testing.T) {
- storage = Storage("./le_test")
-
- if expected, actual := filepath.Join("le_test", "sites"), storage.Sites(); actual != expected {
- t.Errorf("Expected Sites() to return '%s' but got '%s'", expected, actual)
- }
- if expected, actual := filepath.Join("le_test", "sites", "test.com"), storage.Site("test.com"); actual != expected {
- t.Errorf("Expected Site() to return '%s' but got '%s'", expected, actual)
- }
- if expected, actual := filepath.Join("le_test", "sites", "test.com", "test.com.crt"), storage.SiteCertFile("test.com"); actual != expected {
- t.Errorf("Expected SiteCertFile() to return '%s' but got '%s'", expected, actual)
- }
- if expected, actual := filepath.Join("le_test", "sites", "test.com", "test.com.key"), storage.SiteKeyFile("test.com"); actual != expected {
- t.Errorf("Expected SiteKeyFile() to return '%s' but got '%s'", expected, actual)
- }
- if expected, actual := filepath.Join("le_test", "sites", "test.com", "test.com.json"), storage.SiteMetaFile("test.com"); actual != expected {
- t.Errorf("Expected SiteMetaFile() to return '%s' but got '%s'", expected, actual)
- }
- if expected, actual := filepath.Join("le_test", "users"), storage.Users(); actual != expected {
- t.Errorf("Expected Users() to return '%s' but got '%s'", expected, actual)
- }
- if expected, actual := filepath.Join("le_test", "users", "me@example.com"), storage.User("me@example.com"); actual != expected {
- t.Errorf("Expected User() to return '%s' but got '%s'", expected, actual)
- }
- if expected, actual := filepath.Join("le_test", "users", "me@example.com", "me.json"), storage.UserRegFile("me@example.com"); actual != expected {
- t.Errorf("Expected UserRegFile() to return '%s' but got '%s'", expected, actual)
- }
- if expected, actual := filepath.Join("le_test", "users", "me@example.com", "me.key"), storage.UserKeyFile("me@example.com"); actual != expected {
- t.Errorf("Expected UserKeyFile() to return '%s' but got '%s'", expected, actual)
- }
-
- // Test with empty emails
- if expected, actual := filepath.Join("le_test", "users", emptyEmail), storage.User(emptyEmail); actual != expected {
- t.Errorf("Expected User(\"\") to return '%s' but got '%s'", expected, actual)
- }
- if expected, actual := filepath.Join("le_test", "users", emptyEmail, emptyEmail+".json"), storage.UserRegFile(""); actual != expected {
- t.Errorf("Expected UserRegFile(\"\") to return '%s' but got '%s'", expected, actual)
- }
- if expected, actual := filepath.Join("le_test", "users", emptyEmail, emptyEmail+".key"), storage.UserKeyFile(""); actual != expected {
- t.Errorf("Expected UserKeyFile(\"\") to return '%s' but got '%s'", expected, actual)
- }
-}
-
-func TestEmailUsername(t *testing.T) {
- for i, test := range []struct {
- input, expect string
- }{
- {
- input: "username@example.com",
- expect: "username",
- },
- {
- input: "plus+addressing@example.com",
- expect: "plus+addressing",
- },
- {
- input: "me+plus-addressing@example.com",
- expect: "me+plus-addressing",
- },
- {
- input: "not-an-email",
- expect: "not-an-email",
- },
- {
- input: "@foobar.com",
- expect: "foobar.com",
- },
- {
- input: emptyEmail,
- expect: emptyEmail,
- },
- {
- input: "",
- expect: "",
- },
- } {
- if actual := emailUsername(test.input); actual != test.expect {
- t.Errorf("Test %d: Expected username to be '%s' but was '%s'", i, test.expect, actual)
- }
- }
-}
diff --git a/core/https/user.go b/core/https/user.go
deleted file mode 100644
index 9c30c656c..000000000
--- a/core/https/user.go
+++ /dev/null
@@ -1,200 +0,0 @@
-package https
-
-import (
- "bufio"
- "crypto"
- "crypto/ecdsa"
- "crypto/elliptic"
- "crypto/rand"
- "encoding/json"
- "errors"
- "fmt"
- "io"
- "io/ioutil"
- "os"
- "strings"
-
- "github.com/miekg/coredns/server"
- "github.com/xenolf/lego/acme"
-)
-
-// User represents a Let's Encrypt user account.
-type User struct {
- Email string
- Registration *acme.RegistrationResource
- key crypto.PrivateKey
-}
-
-// GetEmail gets u's email.
-func (u User) GetEmail() string {
- return u.Email
-}
-
-// GetRegistration gets u's registration resource.
-func (u User) GetRegistration() *acme.RegistrationResource {
- return u.Registration
-}
-
-// GetPrivateKey gets u's private key.
-func (u User) GetPrivateKey() crypto.PrivateKey {
- return u.key
-}
-
-// getUser loads the user with the given email from disk.
-// If the user does not exist, it will create a new one,
-// but it does NOT save new users to the disk or register
-// them via ACME. It does NOT prompt the user.
-func getUser(email string) (User, error) {
- var user User
-
- // open user file
- regFile, err := os.Open(storage.UserRegFile(email))
- if err != nil {
- if os.IsNotExist(err) {
- // create a new user
- return newUser(email)
- }
- return user, err
- }
- defer regFile.Close()
-
- // load user information
- err = json.NewDecoder(regFile).Decode(&user)
- if err != nil {
- return user, err
- }
-
- // load their private key
- user.key, err = loadPrivateKey(storage.UserKeyFile(email))
- if err != nil {
- return user, err
- }
-
- return user, nil
-}
-
-// saveUser persists a user's key and account registration
-// to the file system. It does NOT register the user via ACME
-// or prompt the user.
-func saveUser(user User) error {
- // make user account folder
- err := os.MkdirAll(storage.User(user.Email), 0700)
- if err != nil {
- return err
- }
-
- // save private key file
- err = savePrivateKey(user.key, storage.UserKeyFile(user.Email))
- if err != nil {
- return err
- }
-
- // save registration file
- jsonBytes, err := json.MarshalIndent(&user, "", "\t")
- if err != nil {
- return err
- }
-
- return ioutil.WriteFile(storage.UserRegFile(user.Email), jsonBytes, 0600)
-}
-
-// newUser creates a new User for the given email address
-// with a new private key. This function does NOT save the
-// user to disk or register it via ACME. If you want to use
-// a user account that might already exist, call getUser
-// instead. It does NOT prompt the user.
-func newUser(email string) (User, error) {
- user := User{Email: email}
- privateKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
- if err != nil {
- return user, errors.New("error generating private key: " + err.Error())
- }
- user.key = privateKey
- return user, nil
-}
-
-// getEmail does everything it can to obtain an email
-// address from the user to use for TLS for cfg. If it
-// cannot get an email address, it returns empty string.
-// (It will warn the user of the consequences of an
-// empty email.) This function MAY prompt the user for
-// input. If userPresent is false, the operator will
-// NOT be prompted and an empty email may be returned.
-func getEmail(cfg server.Config, userPresent bool) string {
- // First try the tls directive from the Corefile
- leEmail := cfg.TLS.LetsEncryptEmail
- if leEmail == "" {
- // Then try memory (command line flag or typed by user previously)
- leEmail = DefaultEmail
- }
- if leEmail == "" {
- // Then try to get most recent user email ~/.coredns/users file
- userDirs, err := ioutil.ReadDir(storage.Users())
- if err == nil {
- var mostRecent os.FileInfo
- for _, dir := range userDirs {
- if !dir.IsDir() {
- continue
- }
- if mostRecent == nil || dir.ModTime().After(mostRecent.ModTime()) {
- leEmail = dir.Name()
- DefaultEmail = leEmail // save for next time
- }
- }
- }
- }
- if leEmail == "" && userPresent {
- // Alas, we must bother the user and ask for an email address;
- // if they proceed they also agree to the SA.
- reader := bufio.NewReader(stdin)
- fmt.Println("\nYour sites will be served over HTTPS automatically using Let's Encrypt.")
- fmt.Println("By continuing, you agree to the Let's Encrypt Subscriber Agreement at:")
- fmt.Println(" " + saURL) // TODO: Show current SA link
- fmt.Println("Please enter your email address so you can recover your account if needed.")
- fmt.Println("You can leave it blank, but you'll lose the ability to recover your account.")
- fmt.Print("Email address: ")
- var err error
- leEmail, err = reader.ReadString('\n')
- if err != nil {
- return ""
- }
- leEmail = strings.TrimSpace(leEmail)
- DefaultEmail = leEmail
- Agreed = true
- }
- return leEmail
-}
-
-// promptUserAgreement prompts the user to agree to the agreement
-// at agreementURL via stdin. If the agreement has changed, then pass
-// true as the second argument. If this is the user's first time
-// agreeing, pass false. It returns whether the user agreed or not.
-func promptUserAgreement(agreementURL string, changed bool) bool {
- if changed {
- fmt.Printf("The Let's Encrypt Subscriber Agreement has changed:\n %s\n", agreementURL)
- fmt.Print("Do you agree to the new terms? (y/n): ")
- } else {
- fmt.Printf("To continue, you must agree to the Let's Encrypt Subscriber Agreement:\n %s\n", agreementURL)
- fmt.Print("Do you agree to the terms? (y/n): ")
- }
-
- reader := bufio.NewReader(stdin)
- answer, err := reader.ReadString('\n')
- if err != nil {
- return false
- }
- answer = strings.ToLower(strings.TrimSpace(answer))
-
- return answer == "y" || answer == "yes"
-}
-
-// stdin is used to read the user's input if prompted;
-// this is changed by tests during tests.
-var stdin = io.ReadWriter(os.Stdin)
-
-// The name of the folder for accounts where the email
-// address was not provided; default 'username' if you will.
-const emptyEmail = "default"
-
-// TODO: Use latest
-const saURL = "https://letsencrypt.org/documents/LE-SA-v1.0.1-July-27-2015.pdf"
diff --git a/core/https/user_test.go b/core/https/user_test.go
deleted file mode 100644
index 3e1af5007..000000000
--- a/core/https/user_test.go
+++ /dev/null
@@ -1,196 +0,0 @@
-package https
-
-import (
- "bytes"
- "crypto/rand"
- "crypto/rsa"
- "io"
- "os"
- "strings"
- "testing"
- "time"
-
- "github.com/miekg/coredns/server"
- "github.com/xenolf/lego/acme"
-)
-
-func TestUser(t *testing.T) {
- privateKey, err := rsa.GenerateKey(rand.Reader, 128)
- if err != nil {
- t.Fatalf("Could not generate test private key: %v", err)
- }
- u := User{
- Email: "me@mine.com",
- Registration: new(acme.RegistrationResource),
- key: privateKey,
- }
-
- if expected, actual := "me@mine.com", u.GetEmail(); actual != expected {
- t.Errorf("Expected email '%s' but got '%s'", expected, actual)
- }
- if u.GetRegistration() == nil {
- t.Error("Expected a registration resource, but got nil")
- }
- if expected, actual := privateKey, u.GetPrivateKey(); actual != expected {
- t.Errorf("Expected the private key at address %p but got one at %p instead ", expected, actual)
- }
-}
-
-func TestNewUser(t *testing.T) {
- email := "me@foobar.com"
- user, err := newUser(email)
- if err != nil {
- t.Fatalf("Error creating user: %v", err)
- }
- if user.key == nil {
- t.Error("Private key is nil")
- }
- if user.Email != email {
- t.Errorf("Expected email to be %s, but was %s", email, user.Email)
- }
- if user.Registration != nil {
- t.Error("New user already has a registration resource; it shouldn't")
- }
-}
-
-func TestSaveUser(t *testing.T) {
- storage = Storage("./testdata")
- defer os.RemoveAll(string(storage))
-
- email := "me@foobar.com"
- user, err := newUser(email)
- if err != nil {
- t.Fatalf("Error creating user: %v", err)
- }
-
- err = saveUser(user)
- if err != nil {
- t.Fatalf("Error saving user: %v", err)
- }
- _, err = os.Stat(storage.UserRegFile(email))
- if err != nil {
- t.Errorf("Cannot access user registration file, error: %v", err)
- }
- _, err = os.Stat(storage.UserKeyFile(email))
- if err != nil {
- t.Errorf("Cannot access user private key file, error: %v", err)
- }
-}
-
-func TestGetUserDoesNotAlreadyExist(t *testing.T) {
- storage = Storage("./testdata")
- defer os.RemoveAll(string(storage))
-
- user, err := getUser("user_does_not_exist@foobar.com")
- if err != nil {
- t.Fatalf("Error getting user: %v", err)
- }
-
- if user.key == nil {
- t.Error("Expected user to have a private key, but it was nil")
- }
-}
-
-func TestGetUserAlreadyExists(t *testing.T) {
- storage = Storage("./testdata")
- defer os.RemoveAll(string(storage))
-
- email := "me@foobar.com"
-
- // Set up test
- user, err := newUser(email)
- if err != nil {
- t.Fatalf("Error creating user: %v", err)
- }
- err = saveUser(user)
- if err != nil {
- t.Fatalf("Error saving user: %v", err)
- }
-
- // Expect to load user from disk
- user2, err := getUser(email)
- if err != nil {
- t.Fatalf("Error getting user: %v", err)
- }
-
- // Assert keys are the same
- if !PrivateKeysSame(user.key, user2.key) {
- t.Error("Expected private key to be the same after loading, but it wasn't")
- }
-
- // Assert emails are the same
- if user.Email != user2.Email {
- t.Errorf("Expected emails to be equal, but was '%s' before and '%s' after loading", user.Email, user2.Email)
- }
-}
-
-func TestGetEmail(t *testing.T) {
- // let's not clutter up the output
- origStdout := os.Stdout
- os.Stdout = nil
- defer func() { os.Stdout = origStdout }()
-
- storage = Storage("./testdata")
- defer os.RemoveAll(string(storage))
- DefaultEmail = "test2@foo.com"
-
- // Test1: Use email in config
- config := server.Config{
- TLS: server.TLSConfig{
- LetsEncryptEmail: "test1@foo.com",
- },
- }
- actual := getEmail(config, true)
- if actual != "test1@foo.com" {
- t.Errorf("Did not get correct email from config; expected '%s' but got '%s'", "test1@foo.com", actual)
- }
-
- // Test2: Use default email from flag (or user previously typing it)
- actual = getEmail(server.Config{}, true)
- if actual != DefaultEmail {
- t.Errorf("Did not get correct email from config; expected '%s' but got '%s'", DefaultEmail, actual)
- }
-
- // Test3: Get input from user
- DefaultEmail = ""
- stdin = new(bytes.Buffer)
- _, err := io.Copy(stdin, strings.NewReader("test3@foo.com\n"))
- if err != nil {
- t.Fatalf("Could not simulate user input, error: %v", err)
- }
- actual = getEmail(server.Config{}, true)
- if actual != "test3@foo.com" {
- t.Errorf("Did not get correct email from user input prompt; expected '%s' but got '%s'", "test3@foo.com", actual)
- }
-
- // Test4: Get most recent email from before
- DefaultEmail = ""
- for i, eml := range []string{
- "test4-3@foo.com",
- "test4-2@foo.com",
- "test4-1@foo.com",
- } {
- u, err := newUser(eml)
- if err != nil {
- t.Fatalf("Error creating user %d: %v", i, err)
- }
- err = saveUser(u)
- if err != nil {
- t.Fatalf("Error saving user %d: %v", i, err)
- }
-
- // Change modified time so they're all different, so the test becomes deterministic
- f, err := os.Stat(storage.User(eml))
- if err != nil {
- t.Fatalf("Could not access user folder for '%s': %v", eml, err)
- }
- chTime := f.ModTime().Add(-(time.Duration(i) * time.Second))
- if err := os.Chtimes(storage.User(eml), chTime, chTime); err != nil {
- t.Fatalf("Could not change user folder mod time for '%s': %v", eml, err)
- }
- }
- actual = getEmail(server.Config{}, true)
- if actual != "test4-3@foo.com" {
- t.Errorf("Did not get correct email from storage; expected '%s' but got '%s'", "test4-3@foo.com", actual)
- }
-}
diff --git a/core/parse/dispenser.go b/core/parse/dispenser.go
deleted file mode 100644
index 08aa6e76d..000000000
--- a/core/parse/dispenser.go
+++ /dev/null
@@ -1,251 +0,0 @@
-package parse
-
-import (
- "errors"
- "fmt"
- "io"
- "strings"
-)
-
-// Dispenser is a type that dispenses tokens, similarly to a lexer,
-// except that it can do so with some notion of structure and has
-// some really convenient methods.
-type Dispenser struct {
- filename string
- tokens []token
- cursor int
- nesting int
-}
-
-// NewDispenser returns a Dispenser, ready to use for parsing the given input.
-func NewDispenser(filename string, input io.Reader) Dispenser {
- return Dispenser{
- filename: filename,
- tokens: allTokens(input),
- cursor: -1,
- }
-}
-
-// NewDispenserTokens returns a Dispenser filled with the given tokens.
-func NewDispenserTokens(filename string, tokens []token) Dispenser {
- return Dispenser{
- filename: filename,
- tokens: tokens,
- cursor: -1,
- }
-}
-
-// Next loads the next token. Returns true if a token
-// was loaded; false otherwise. If false, all tokens
-// have been consumed.
-func (d *Dispenser) Next() bool {
- if d.cursor < len(d.tokens)-1 {
- d.cursor++
- return true
- }
- return false
-}
-
-// NextArg loads the next token if it is on the same
-// line. Returns true if a token was loaded; false
-// otherwise. If false, all tokens on the line have
-// been consumed. It handles imported tokens correctly.
-func (d *Dispenser) NextArg() bool {
- if d.cursor < 0 {
- d.cursor++
- return true
- }
- if d.cursor >= len(d.tokens) {
- return false
- }
- if d.cursor < len(d.tokens)-1 &&
- d.tokens[d.cursor].file == d.tokens[d.cursor+1].file &&
- d.tokens[d.cursor].line+d.numLineBreaks(d.cursor) == d.tokens[d.cursor+1].line {
- d.cursor++
- return true
- }
- return false
-}
-
-// NextLine loads the next token only if it is not on the same
-// line as the current token, and returns true if a token was
-// loaded; false otherwise. If false, there is not another token
-// or it is on the same line. It handles imported tokens correctly.
-func (d *Dispenser) NextLine() bool {
- if d.cursor < 0 {
- d.cursor++
- return true
- }
- if d.cursor >= len(d.tokens) {
- return false
- }
- if d.cursor < len(d.tokens)-1 &&
- (d.tokens[d.cursor].file != d.tokens[d.cursor+1].file ||
- d.tokens[d.cursor].line+d.numLineBreaks(d.cursor) < d.tokens[d.cursor+1].line) {
- d.cursor++
- return true
- }
- return false
-}
-
-// NextBlock can be used as the condition of a for loop
-// to load the next token as long as it opens a block or
-// is already in a block. It returns true if a token was
-// loaded, or false when the block's closing curly brace
-// was loaded and thus the block ended. Nested blocks are
-// not supported.
-func (d *Dispenser) NextBlock() bool {
- if d.nesting > 0 {
- d.Next()
- if d.Val() == "}" {
- d.nesting--
- return false
- }
- return true
- }
- if !d.NextArg() { // block must open on same line
- return false
- }
- if d.Val() != "{" {
- d.cursor-- // roll back if not opening brace
- return false
- }
- d.Next()
- if d.Val() == "}" {
- // Open and then closed right away
- return false
- }
- d.nesting++
- return true
-}
-
-// IncrNest adds a level of nesting to the dispenser.
-func (d *Dispenser) IncrNest() {
- d.nesting++
- return
-}
-
-// Val gets the text of the current token. If there is no token
-// loaded, it returns empty string.
-func (d *Dispenser) Val() string {
- if d.cursor < 0 || d.cursor >= len(d.tokens) {
- return ""
- }
- return d.tokens[d.cursor].text
-}
-
-// Line gets the line number of the current token. If there is no token
-// loaded, it returns 0.
-func (d *Dispenser) Line() int {
- if d.cursor < 0 || d.cursor >= len(d.tokens) {
- return 0
- }
- return d.tokens[d.cursor].line
-}
-
-// File gets the filename of the current token. If there is no token loaded,
-// it returns the filename originally given when parsing started.
-func (d *Dispenser) File() string {
- if d.cursor < 0 || d.cursor >= len(d.tokens) {
- return d.filename
- }
- if tokenFilename := d.tokens[d.cursor].file; tokenFilename != "" {
- return tokenFilename
- }
- return d.filename
-}
-
-// Args is a convenience function that loads the next arguments
-// (tokens on the same line) into an arbitrary number of strings
-// pointed to in targets. If there are fewer tokens available
-// than string pointers, the remaining strings will not be changed
-// and false will be returned. If there were enough tokens available
-// to fill the arguments, then true will be returned.
-func (d *Dispenser) Args(targets ...*string) bool {
- enough := true
- for i := 0; i < len(targets); i++ {
- if !d.NextArg() {
- enough = false
- break
- }
- *targets[i] = d.Val()
- }
- return enough
-}
-
-// RemainingArgs loads any more arguments (tokens on the same line)
-// into a slice and returns them. Open curly brace tokens also indicate
-// the end of arguments, and the curly brace is not included in
-// the return value nor is it loaded.
-func (d *Dispenser) RemainingArgs() []string {
- var args []string
-
- for d.NextArg() {
- if d.Val() == "{" {
- d.cursor--
- break
- }
- args = append(args, d.Val())
- }
-
- return args
-}
-
-// ArgErr returns an argument error, meaning that another
-// argument was expected but not found. In other words,
-// a line break or open curly brace was encountered instead of
-// an argument.
-func (d *Dispenser) ArgErr() error {
- if d.Val() == "{" {
- return d.Err("Unexpected token '{', expecting argument")
- }
- return d.Errf("Wrong argument count or unexpected line ending after '%s'", d.Val())
-}
-
-// SyntaxErr creates a generic syntax error which explains what was
-// found and what was expected.
-func (d *Dispenser) SyntaxErr(expected string) error {
- msg := fmt.Sprintf("%s:%d - Syntax error: Unexpected token '%s', expecting '%s'", d.File(), d.Line(), d.Val(), expected)
- return errors.New(msg)
-}
-
-// EOFErr returns an error indicating that the dispenser reached
-// the end of the input when searching for the next token.
-func (d *Dispenser) EOFErr() error {
- return d.Errf("Unexpected EOF")
-}
-
-// Err generates a custom parse error with a message of msg.
-func (d *Dispenser) Err(msg string) error {
- msg = fmt.Sprintf("%s:%d - Parse error: %s", d.File(), d.Line(), msg)
- return errors.New(msg)
-}
-
-// Errf is like Err, but for formatted error messages
-func (d *Dispenser) Errf(format string, args ...interface{}) error {
- return d.Err(fmt.Sprintf(format, args...))
-}
-
-// numLineBreaks counts how many line breaks are in the token
-// value given by the token index tknIdx. It returns 0 if the
-// token does not exist or there are no line breaks.
-func (d *Dispenser) numLineBreaks(tknIdx int) int {
- if tknIdx < 0 || tknIdx >= len(d.tokens) {
- return 0
- }
- return strings.Count(d.tokens[tknIdx].text, "\n")
-}
-
-// isNewLine determines whether the current token is on a different
-// line (higher line number) than the previous token. It handles imported
-// tokens correctly. If there isn't a previous token, it returns true.
-func (d *Dispenser) isNewLine() bool {
- if d.cursor < 1 {
- return true
- }
- if d.cursor > len(d.tokens)-1 {
- return false
- }
- return d.tokens[d.cursor-1].file != d.tokens[d.cursor].file ||
- d.tokens[d.cursor-1].line+d.numLineBreaks(d.cursor-1) < d.tokens[d.cursor].line
-}
diff --git a/core/parse/dispenser_test.go b/core/parse/dispenser_test.go
deleted file mode 100644
index 20a7ddcac..000000000
--- a/core/parse/dispenser_test.go
+++ /dev/null
@@ -1,292 +0,0 @@
-package parse
-
-import (
- "reflect"
- "strings"
- "testing"
-)
-
-func TestDispenser_Val_Next(t *testing.T) {
- input := `host:port
- dir1 arg1
- dir2 arg2 arg3
- dir3`
- d := NewDispenser("Testfile", strings.NewReader(input))
-
- if val := d.Val(); val != "" {
- t.Fatalf("Val(): Should return empty string when no token loaded; got '%s'", val)
- }
-
- assertNext := func(shouldLoad bool, expectedCursor int, expectedVal string) {
- if loaded := d.Next(); loaded != shouldLoad {
- t.Errorf("Next(): Expected %v but got %v instead (val '%s')", shouldLoad, loaded, d.Val())
- }
- if d.cursor != expectedCursor {
- t.Errorf("Expected cursor to be %d, but was %d", expectedCursor, d.cursor)
- }
- if d.nesting != 0 {
- t.Errorf("Nesting should be 0, was %d instead", d.nesting)
- }
- if val := d.Val(); val != expectedVal {
- t.Errorf("Val(): Expected '%s' but got '%s'", expectedVal, val)
- }
- }
-
- assertNext(true, 0, "host:port")
- assertNext(true, 1, "dir1")
- assertNext(true, 2, "arg1")
- assertNext(true, 3, "dir2")
- assertNext(true, 4, "arg2")
- assertNext(true, 5, "arg3")
- assertNext(true, 6, "dir3")
- // Note: This next test simply asserts existing behavior.
- // If desired, we may wish to empty the token value after
- // reading past the EOF. Open an issue if you want this change.
- assertNext(false, 6, "dir3")
-}
-
-func TestDispenser_NextArg(t *testing.T) {
- input := `dir1 arg1
- dir2 arg2 arg3
- dir3`
- d := NewDispenser("Testfile", strings.NewReader(input))
-
- assertNext := func(shouldLoad bool, expectedVal string, expectedCursor int) {
- if d.Next() != shouldLoad {
- t.Errorf("Next(): Should load token but got false instead (val: '%s')", d.Val())
- }
- if d.cursor != expectedCursor {
- t.Errorf("Next(): Expected cursor to be at %d, but it was %d", expectedCursor, d.cursor)
- }
- if val := d.Val(); val != expectedVal {
- t.Errorf("Val(): Expected '%s' but got '%s'", expectedVal, val)
- }
- }
-
- assertNextArg := func(expectedVal string, loadAnother bool, expectedCursor int) {
- if d.NextArg() != true {
- t.Error("NextArg(): Should load next argument but got false instead")
- }
- if d.cursor != expectedCursor {
- t.Errorf("NextArg(): Expected cursor to be at %d, but it was %d", expectedCursor, d.cursor)
- }
- if val := d.Val(); val != expectedVal {
- t.Errorf("Val(): Expected '%s' but got '%s'", expectedVal, val)
- }
- if !loadAnother {
- if d.NextArg() != false {
- t.Fatalf("NextArg(): Should NOT load another argument, but got true instead (val: '%s')", d.Val())
- }
- if d.cursor != expectedCursor {
- t.Errorf("NextArg(): Expected cursor to remain at %d, but it was %d", expectedCursor, d.cursor)
- }
- }
- }
-
- assertNext(true, "dir1", 0)
- assertNextArg("arg1", false, 1)
- assertNext(true, "dir2", 2)
- assertNextArg("arg2", true, 3)
- assertNextArg("arg3", false, 4)
- assertNext(true, "dir3", 5)
- assertNext(false, "dir3", 5)
-}
-
-func TestDispenser_NextLine(t *testing.T) {
- input := `host:port
- dir1 arg1
- dir2 arg2 arg3`
- d := NewDispenser("Testfile", strings.NewReader(input))
-
- assertNextLine := func(shouldLoad bool, expectedVal string, expectedCursor int) {
- if d.NextLine() != shouldLoad {
- t.Errorf("NextLine(): Should load token but got false instead (val: '%s')", d.Val())
- }
- if d.cursor != expectedCursor {
- t.Errorf("NextLine(): Expected cursor to be %d, instead was %d", expectedCursor, d.cursor)
- }
- if val := d.Val(); val != expectedVal {
- t.Errorf("Val(): Expected '%s' but got '%s'", expectedVal, val)
- }
- }
-
- assertNextLine(true, "host:port", 0)
- assertNextLine(true, "dir1", 1)
- assertNextLine(false, "dir1", 1)
- d.Next() // arg1
- assertNextLine(true, "dir2", 3)
- assertNextLine(false, "dir2", 3)
- d.Next() // arg2
- assertNextLine(false, "arg2", 4)
- d.Next() // arg3
- assertNextLine(false, "arg3", 5)
-}
-
-func TestDispenser_NextBlock(t *testing.T) {
- input := `foobar1 {
- sub1 arg1
- sub2
- }
- foobar2 {
- }`
- d := NewDispenser("Testfile", strings.NewReader(input))
-
- assertNextBlock := func(shouldLoad bool, expectedCursor, expectedNesting int) {
- if loaded := d.NextBlock(); loaded != shouldLoad {
- t.Errorf("NextBlock(): Should return %v but got %v", shouldLoad, loaded)
- }
- if d.cursor != expectedCursor {
- t.Errorf("NextBlock(): Expected cursor to be %d, was %d", expectedCursor, d.cursor)
- }
- if d.nesting != expectedNesting {
- t.Errorf("NextBlock(): Nesting should be %d, not %d", expectedNesting, d.nesting)
- }
- }
-
- assertNextBlock(false, -1, 0)
- d.Next() // foobar1
- assertNextBlock(true, 2, 1)
- assertNextBlock(true, 3, 1)
- assertNextBlock(true, 4, 1)
- assertNextBlock(false, 5, 0)
- d.Next() // foobar2
- assertNextBlock(false, 8, 0) // empty block is as if it didn't exist
-}
-
-func TestDispenser_Args(t *testing.T) {
- var s1, s2, s3 string
- input := `dir1 arg1 arg2 arg3
- dir2 arg4 arg5
- dir3 arg6 arg7
- dir4`
- d := NewDispenser("Testfile", strings.NewReader(input))
-
- d.Next() // dir1
-
- // As many strings as arguments
- if all := d.Args(&s1, &s2, &s3); !all {
- t.Error("Args(): Expected true, got false")
- }
- if s1 != "arg1" {
- t.Errorf("Args(): Expected s1 to be 'arg1', got '%s'", s1)
- }
- if s2 != "arg2" {
- t.Errorf("Args(): Expected s2 to be 'arg2', got '%s'", s2)
- }
- if s3 != "arg3" {
- t.Errorf("Args(): Expected s3 to be 'arg3', got '%s'", s3)
- }
-
- d.Next() // dir2
-
- // More strings than arguments
- if all := d.Args(&s1, &s2, &s3); all {
- t.Error("Args(): Expected false, got true")
- }
- if s1 != "arg4" {
- t.Errorf("Args(): Expected s1 to be 'arg4', got '%s'", s1)
- }
- if s2 != "arg5" {
- t.Errorf("Args(): Expected s2 to be 'arg5', got '%s'", s2)
- }
- if s3 != "arg3" {
- t.Errorf("Args(): Expected s3 to be unchanged ('arg3'), instead got '%s'", s3)
- }
-
- // (quick cursor check just for kicks and giggles)
- if d.cursor != 6 {
- t.Errorf("Cursor should be 6, but is %d", d.cursor)
- }
-
- d.Next() // dir3
-
- // More arguments than strings
- if all := d.Args(&s1); !all {
- t.Error("Args(): Expected true, got false")
- }
- if s1 != "arg6" {
- t.Errorf("Args(): Expected s1 to be 'arg6', got '%s'", s1)
- }
-
- d.Next() // dir4
-
- // No arguments or strings
- if all := d.Args(); !all {
- t.Error("Args(): Expected true, got false")
- }
-
- // No arguments but at least one string
- if all := d.Args(&s1); all {
- t.Error("Args(): Expected false, got true")
- }
-}
-
-func TestDispenser_RemainingArgs(t *testing.T) {
- input := `dir1 arg1 arg2 arg3
- dir2 arg4 arg5
- dir3 arg6 { arg7
- dir4`
- d := NewDispenser("Testfile", strings.NewReader(input))
-
- d.Next() // dir1
-
- args := d.RemainingArgs()
- if expected := []string{"arg1", "arg2", "arg3"}; !reflect.DeepEqual(args, expected) {
- t.Errorf("RemainingArgs(): Expected %v, got %v", expected, args)
- }
-
- d.Next() // dir2
-
- args = d.RemainingArgs()
- if expected := []string{"arg4", "arg5"}; !reflect.DeepEqual(args, expected) {
- t.Errorf("RemainingArgs(): Expected %v, got %v", expected, args)
- }
-
- d.Next() // dir3
-
- args = d.RemainingArgs()
- if expected := []string{"arg6"}; !reflect.DeepEqual(args, expected) {
- t.Errorf("RemainingArgs(): Expected %v, got %v", expected, args)
- }
-
- d.Next() // {
- d.Next() // arg7
- d.Next() // dir4
-
- args = d.RemainingArgs()
- if len(args) != 0 {
- t.Errorf("RemainingArgs(): Expected %v, got %v", []string{}, args)
- }
-}
-
-func TestDispenser_ArgErr_Err(t *testing.T) {
- input := `dir1 {
- }
- dir2 arg1 arg2`
- d := NewDispenser("Testfile", strings.NewReader(input))
-
- d.cursor = 1 // {
-
- if err := d.ArgErr(); err == nil || !strings.Contains(err.Error(), "{") {
- t.Errorf("ArgErr(): Expected an error message with { in it, but got '%v'", err)
- }
-
- d.cursor = 5 // arg2
-
- if err := d.ArgErr(); err == nil || !strings.Contains(err.Error(), "arg2") {
- t.Errorf("ArgErr(): Expected an error message with 'arg2' in it; got '%v'", err)
- }
-
- err := d.Err("foobar")
- if err == nil {
- t.Fatalf("Err(): Expected an error, got nil")
- }
-
- if !strings.Contains(err.Error(), "Testfile:3") {
- t.Errorf("Expected error message with filename:line in it; got '%v'", err)
- }
-
- if !strings.Contains(err.Error(), "foobar") {
- t.Errorf("Expected error message with custom message in it ('foobar'); got '%v'", err)
- }
-}
diff --git a/core/parse/import_glob0.txt b/core/parse/import_glob0.txt
deleted file mode 100644
index e610b5e7c..000000000
--- a/core/parse/import_glob0.txt
+++ /dev/null
@@ -1,6 +0,0 @@
-glob0.host0 {
- dir2 arg1
-}
-
-glob0.host1 {
-}
diff --git a/core/parse/import_glob1.txt b/core/parse/import_glob1.txt
deleted file mode 100644
index 111eb044d..000000000
--- a/core/parse/import_glob1.txt
+++ /dev/null
@@ -1,4 +0,0 @@
-glob1.host0 {
- dir1
- dir2 arg1
-}
diff --git a/core/parse/import_glob2.txt b/core/parse/import_glob2.txt
deleted file mode 100644
index c09f784ec..000000000
--- a/core/parse/import_glob2.txt
+++ /dev/null
@@ -1,3 +0,0 @@
-glob2.host0 {
- dir2 arg1
-}
diff --git a/core/parse/import_test1.txt b/core/parse/import_test1.txt
deleted file mode 100644
index dac7b29be..000000000
--- a/core/parse/import_test1.txt
+++ /dev/null
@@ -1,2 +0,0 @@
-dir2 arg1 arg2
-dir3
\ No newline at end of file
diff --git a/core/parse/import_test2.txt b/core/parse/import_test2.txt
deleted file mode 100644
index 140c87939..000000000
--- a/core/parse/import_test2.txt
+++ /dev/null
@@ -1,4 +0,0 @@
-host1 {
- dir1
- dir2 arg1
-}
\ No newline at end of file
diff --git a/core/parse/lexer.go b/core/parse/lexer.go
deleted file mode 100644
index d2939eba2..000000000
--- a/core/parse/lexer.go
+++ /dev/null
@@ -1,122 +0,0 @@
-package parse
-
-import (
- "bufio"
- "io"
- "unicode"
-)
-
-type (
- // lexer is a utility which can get values, token by
- // token, from a Reader. A token is a word, and tokens
- // are separated by whitespace. A word can be enclosed
- // in quotes if it contains whitespace.
- lexer struct {
- reader *bufio.Reader
- token token
- line int
- }
-
- // token represents a single parsable unit.
- token struct {
- file string
- line int
- text string
- }
-)
-
-// load prepares the lexer to scan an input for tokens.
-func (l *lexer) load(input io.Reader) error {
- l.reader = bufio.NewReader(input)
- l.line = 1
- return nil
-}
-
-// next loads the next token into the lexer.
-// A token is delimited by whitespace, unless
-// the token starts with a quotes character (")
-// in which case the token goes until the closing
-// quotes (the enclosing quotes are not included).
-// Inside quoted strings, quotes may be escaped
-// with a preceding \ character. No other chars
-// may be escaped. The rest of the line is skipped
-// if a "#" character is read in. Returns true if
-// a token was loaded; false otherwise.
-func (l *lexer) next() bool {
- var val []rune
- var comment, quoted, escaped bool
-
- makeToken := func() bool {
- l.token.text = string(val)
- return true
- }
-
- for {
- ch, _, err := l.reader.ReadRune()
- if err != nil {
- if len(val) > 0 {
- return makeToken()
- }
- if err == io.EOF {
- return false
- }
- panic(err)
- }
-
- if quoted {
- if !escaped {
- if ch == '\\' {
- escaped = true
- continue
- } else if ch == '"' {
- quoted = false
- return makeToken()
- }
- }
- if ch == '\n' {
- l.line++
- }
- if escaped {
- // only escape quotes
- if ch != '"' {
- val = append(val, '\\')
- }
- }
- val = append(val, ch)
- escaped = false
- continue
- }
-
- if unicode.IsSpace(ch) {
- if ch == '\r' {
- continue
- }
- if ch == '\n' {
- l.line++
- comment = false
- }
- if len(val) > 0 {
- return makeToken()
- }
- continue
- }
-
- if ch == '#' {
- comment = true
- }
-
- if comment {
- continue
- }
-
- if len(val) == 0 {
- l.token = token{line: l.line}
- if ch == '"' {
- quoted = true
- continue
- }
- }
-
- val = append(val, ch)
- }
-}
diff --git a/core/parse/lexer_test.go b/core/parse/lexer_test.go
deleted file mode 100644
index f12c7e7dc..000000000
--- a/core/parse/lexer_test.go
+++ /dev/null
@@ -1,165 +0,0 @@
-package parse
-
-import (
- "strings"
- "testing"
-)
-
-type lexerTestCase struct {
- input string
- expected []token
-}
-
-func TestLexer(t *testing.T) {
- testCases := []lexerTestCase{
- {
- input: `host:123`,
- expected: []token{
- {line: 1, text: "host:123"},
- },
- },
- {
- input: `host:123
-
- directive`,
- expected: []token{
- {line: 1, text: "host:123"},
- {line: 3, text: "directive"},
- },
- },
- {
- input: `host:123 {
- directive
- }`,
- expected: []token{
- {line: 1, text: "host:123"},
- {line: 1, text: "{"},
- {line: 2, text: "directive"},
- {line: 3, text: "}"},
- },
- },
- {
- input: `host:123 { directive }`,
- expected: []token{
- {line: 1, text: "host:123"},
- {line: 1, text: "{"},
- {line: 1, text: "directive"},
- {line: 1, text: "}"},
- },
- },
- {
- input: `host:123 {
- #comment
- directive
- # comment
- foobar # another comment
- }`,
- expected: []token{
- {line: 1, text: "host:123"},
- {line: 1, text: "{"},
- {line: 3, text: "directive"},
- {line: 5, text: "foobar"},
- {line: 6, text: "}"},
- },
- },
- {
- input: `a "quoted value" b
- foobar`,
- expected: []token{
- {line: 1, text: "a"},
- {line: 1, text: "quoted value"},
- {line: 1, text: "b"},
- {line: 2, text: "foobar"},
- },
- },
- {
- input: `A "quoted \"value\" inside" B`,
- expected: []token{
- {line: 1, text: "A"},
- {line: 1, text: `quoted "value" inside`},
- {line: 1, text: "B"},
- },
- },
- {
- input: `"don't\escape"`,
- expected: []token{
- {line: 1, text: `don't\escape`},
- },
- },
- {
- input: `"don't\\escape"`,
- expected: []token{
- {line: 1, text: `don't\\escape`},
- },
- },
- {
- input: `A "quoted value with line
- break inside" {
- foobar
- }`,
- expected: []token{
- {line: 1, text: "A"},
- {line: 1, text: "quoted value with line\n\t\t\t\t\tbreak inside"},
- {line: 2, text: "{"},
- {line: 3, text: "foobar"},
- {line: 4, text: "}"},
- },
- },
- {
- input: `"C:\php\php-cgi.exe"`,
- expected: []token{
- {line: 1, text: `C:\php\php-cgi.exe`},
- },
- },
- {
- input: `empty "" string`,
- expected: []token{
- {line: 1, text: `empty`},
- {line: 1, text: ``},
- {line: 1, text: `string`},
- },
- },
- {
- input: "skip those\r\nCR characters",
- expected: []token{
- {line: 1, text: "skip"},
- {line: 1, text: "those"},
- {line: 2, text: "CR"},
- {line: 2, text: "characters"},
- },
- },
- }
-
- for i, testCase := range testCases {
- actual := tokenize(testCase.input)
- lexerCompare(t, i, testCase.expected, actual)
- }
-}
-
-func tokenize(input string) (tokens []token) {
- l := lexer{}
- l.load(strings.NewReader(input))
- for l.next() {
- tokens = append(tokens, l.token)
- }
- return
-}
-
-func lexerCompare(t *testing.T, n int, expected, actual []token) {
- if len(expected) != len(actual) {
- t.Errorf("Test case %d: expected %d token(s) but got %d", n, len(expected), len(actual))
- }
-
- for i := 0; i < len(actual) && i < len(expected); i++ {
- if actual[i].line != expected[i].line {
- t.Errorf("Test case %d token %d ('%s'): expected line %d but was line %d",
- n, i, expected[i].text, expected[i].line, actual[i].line)
- break
- }
- if actual[i].text != expected[i].text {
- t.Errorf("Test case %d token %d: expected text '%s' but was '%s'",
- n, i, expected[i].text, actual[i].text)
- break
- }
- }
-}
diff --git a/core/parse/parse.go b/core/parse/parse.go
deleted file mode 100644
index faef36c28..000000000
--- a/core/parse/parse.go
+++ /dev/null
@@ -1,32 +0,0 @@
-// Package parse provides facilities for parsing configuration files.
-package parse
-
-import "io"
-
-// ServerBlocks parses the input just enough to organize tokens,
-// in order, by server block. No further parsing is performed.
-// If checkDirectives is true, only valid directives will be allowed
-// otherwise we consider it a parse error. Server blocks are returned
-// in the order in which they appear.
-func ServerBlocks(filename string, input io.Reader, checkDirectives bool) ([]ServerBlock, error) {
- p := parser{Dispenser: NewDispenser(filename, input)}
- p.checkDirectives = checkDirectives
- blocks, err := p.parseAll()
- return blocks, err
-}
-
-// allTokens lexes the entire input, but does not parse it.
-// It returns all the tokens from the input, unstructured
-// and in order.
-func allTokens(input io.Reader) (tokens []token) {
- l := new(lexer)
- l.load(input)
- for l.next() {
- tokens = append(tokens, l.token)
- }
- return
-}
-
-// ValidDirectives is a set of directives that are valid (unordered). Populated
-// by config package's init function.
-var ValidDirectives = make(map[string]struct{})
diff --git a/core/parse/parse_test.go b/core/parse/parse_test.go
deleted file mode 100644
index 48746300f..000000000
--- a/core/parse/parse_test.go
+++ /dev/null
@@ -1,22 +0,0 @@
-package parse
-
-import (
- "strings"
- "testing"
-)
-
-func TestAllTokens(t *testing.T) {
- input := strings.NewReader("a b c\nd e")
- expected := []string{"a", "b", "c", "d", "e"}
- tokens := allTokens(input)
-
- if len(tokens) != len(expected) {
- t.Fatalf("Expected %d tokens, got %d", len(expected), len(tokens))
- }
-
- for i, val := range expected {
- if tokens[i].text != val {
- t.Errorf("Token %d should be '%s' but was '%s'", i, val, tokens[i].text)
- }
- }
-}
diff --git a/core/parse/parsing.go b/core/parse/parsing.go
deleted file mode 100644
index 7be7d6714..000000000
--- a/core/parse/parsing.go
+++ /dev/null
@@ -1,388 +0,0 @@
-package parse
-
-import (
- "fmt"
- "net"
- "os"
- "path/filepath"
- "strings"
-
- "github.com/miekg/dns"
-)
-
-type parser struct {
- Dispenser
- block ServerBlock // current server block being parsed
- eof bool // if we encounter a valid EOF in a hard place
- checkDirectives bool // if true, directives must be known
-}
-
-func (p *parser) parseAll() ([]ServerBlock, error) {
- var blocks []ServerBlock
-
- for p.Next() {
- err := p.parseOne()
- if err != nil {
- return blocks, err
- }
- if len(p.block.Addresses) > 0 {
- blocks = append(blocks, p.block)
- }
- }
-
- return blocks, nil
-}
-
-func (p *parser) parseOne() error {
- p.block = ServerBlock{Tokens: make(map[string][]token)}
-
- err := p.begin()
- if err != nil {
- return err
- }
-
- return nil
-}
-
-func (p *parser) begin() error {
- if len(p.tokens) == 0 {
- return nil
- }
-
- err := p.addresses()
- if err != nil {
- return err
- }
-
- if p.eof {
- // this happens if the Corefile consists of only
- // a line of addresses and nothing else
- return nil
- }
-
- err = p.blockContents()
- if err != nil {
- return err
- }
-
- return nil
-}
-
-func (p *parser) addresses() error {
- var expectingAnother bool
-
- for {
- tkn := replaceEnvVars(p.Val())
-
- // special case: import directive replaces tokens during parse-time
- if tkn == "import" && p.isNewLine() {
- err := p.doImport()
- if err != nil {
- return err
- }
- continue
- }
-
- // Open brace definitely indicates end of addresses
- if tkn == "{" {
- if expectingAnother {
- return p.Errf("Expected another address but had '%s' - check for extra comma", tkn)
- }
- break
- }
-
- if tkn != "" { // empty token possible if user typed "" in Corefile
- // Trailing comma indicates another address will follow, which
- // may possibly be on the next line
- if tkn[len(tkn)-1] == ',' {
- tkn = tkn[:len(tkn)-1]
- expectingAnother = true
- } else {
- expectingAnother = false // but we may still see another one on this line
- }
-
- // Parse and save this address
- addr, err := standardAddress(tkn)
- if err != nil {
- return err
- }
- p.block.Addresses = append(p.block.Addresses, addr)
- }
-
- // Advance token and possibly break out of loop or return error
- hasNext := p.Next()
- if expectingAnother && !hasNext {
- return p.EOFErr()
- }
- if !hasNext {
- p.eof = true
- break // EOF
- }
- if !expectingAnother && p.isNewLine() {
- break
- }
- }
-
- return nil
-}
-
-func (p *parser) blockContents() error {
- errOpenCurlyBrace := p.openCurlyBrace()
- if errOpenCurlyBrace != nil {
- // single-server configs don't need curly braces
- p.cursor--
- }
-
- err := p.directives()
- if err != nil {
- return err
- }
-
- // Only look for close curly brace if there was an opening
- if errOpenCurlyBrace == nil {
- err = p.closeCurlyBrace()
- if err != nil {
- return err
- }
- }
-
- return nil
-}
-
-// directives parses through all the lines for directives
-// and it expects the next token to be the first
-// directive. It goes until EOF or closing curly brace
-// which ends the server block.
-func (p *parser) directives() error {
- for p.Next() {
- // end of server block
- if p.Val() == "}" {
- break
- }
-
- // special case: import directive replaces tokens during parse-time
- if p.Val() == "import" {
- err := p.doImport()
- if err != nil {
- return err
- }
- p.cursor-- // cursor is advanced when we continue, so roll back one more
- continue
- }
-
- // normal case: parse a directive on this line
- if err := p.directive(); err != nil {
- return err
- }
- }
- return nil
-}
-
-// doImport swaps out the import directive and its argument
-// (a total of 2 tokens) with the tokens in the specified file
-// or globbing pattern. When the function returns, the cursor
-// is on the token before where the import directive was. In
-// other words, call Next() to access the first token that was
-// imported.
-func (p *parser) doImport() error {
- // syntax check
- if !p.NextArg() {
- return p.ArgErr()
- }
- importPattern := p.Val()
- if p.NextArg() {
- return p.Err("Import takes only one argument (glob pattern or file)")
- }
-
- // do glob
- matches, err := filepath.Glob(importPattern)
- if err != nil {
- return p.Errf("Failed to use import pattern %s: %v", importPattern, err)
- }
- if len(matches) == 0 {
- return p.Errf("No files matching import pattern %s", importPattern)
- }
-
- // splice out the import directive and its argument (2 tokens total)
- tokensBefore := p.tokens[:p.cursor-1]
- tokensAfter := p.tokens[p.cursor+1:]
-
- // collect all the imported tokens
- var importedTokens []token
- for _, importFile := range matches {
- newTokens, err := p.doSingleImport(importFile)
- if err != nil {
- return err
- }
- importedTokens = append(importedTokens, newTokens...)
- }
-
- // splice the imported tokens in the place of the import statement
- // and rewind cursor so Next() will land on first imported token
- p.tokens = append(tokensBefore, append(importedTokens, tokensAfter...)...)
- p.cursor--
-
- return nil
-}
-
-// doSingleImport lexes the individual file at importFile and returns
-// its tokens or an error, if any.
-func (p *parser) doSingleImport(importFile string) ([]token, error) {
- file, err := os.Open(importFile)
- if err != nil {
- return nil, p.Errf("Could not import %s: %v", importFile, err)
- }
- defer file.Close()
- importedTokens := allTokens(file)
-
- // Tack the filename onto these tokens so errors show the imported file's name
- filename := filepath.Base(importFile)
- for i := 0; i < len(importedTokens); i++ {
- importedTokens[i].file = filename
- }
-
- return importedTokens, nil
-}
-
-// directive collects tokens until the directive's scope
-// closes (either end of line or end of curly brace block).
-// It expects the currently-loaded token to be a directive
-// (or } that ends a server block). The collected tokens
-// are loaded into the current server block for later use
-// by directive setup functions.
-func (p *parser) directive() error {
- dir := p.Val()
- nesting := 0
-
- if p.checkDirectives {
- if _, ok := ValidDirectives[dir]; !ok {
- return p.Errf("Unknown directive '%s'", dir)
- }
- }
-
- // The directive itself is appended as a relevant token
- p.block.Tokens[dir] = append(p.block.Tokens[dir], p.tokens[p.cursor])
-
- for p.Next() {
- if p.Val() == "{" {
- nesting++
- } else if p.isNewLine() && nesting == 0 {
- p.cursor-- // read too far
- break
- } else if p.Val() == "}" && nesting > 0 {
- nesting--
- } else if p.Val() == "}" && nesting == 0 {
- return p.Err("Unexpected '}' because no matching opening brace")
- }
- p.tokens[p.cursor].text = replaceEnvVars(p.tokens[p.cursor].text)
- p.block.Tokens[dir] = append(p.block.Tokens[dir], p.tokens[p.cursor])
- }
-
- if nesting > 0 {
- return p.EOFErr()
- }
- return nil
-}
-
-// openCurlyBrace expects the current token to be an
-// opening curly brace. This acts like an assertion
-// because it returns an error if the token is not
-// a opening curly brace. It does NOT advance the token.
-func (p *parser) openCurlyBrace() error {
- if p.Val() != "{" {
- return p.SyntaxErr("{")
- }
- return nil
-}
-
-// closeCurlyBrace expects the current token to be
-// a closing curly brace. This acts like an assertion
-// because it returns an error if the token is not
-// a closing curly brace. It does NOT advance the token.
-func (p *parser) closeCurlyBrace() error {
- if p.Val() != "}" {
- return p.SyntaxErr("}")
- }
- return nil
-}
-
-// standardAddress parses an address string into a structured format with separate
-// host, and port portions, as well as the original input string.
-func standardAddress(str string) (address, error) {
- var err error
-
- // first check for scheme and strip it off
- input := str
-
- // separate host and port
- host, port, err := net.SplitHostPort(str)
- if err != nil {
- host, port, err = net.SplitHostPort(str + ":")
- // no error check here; return err at end of function
- }
-
- if len(host) > 255 {
- return address{}, fmt.Errorf("specified address is too long: %d > 255", len(host))
- }
- _, d := dns.IsDomainName(host)
- if !d {
- return address{}, fmt.Errorf("host is not a valid domain: %s", host)
- }
-
- // see if we can set port based off scheme
- if port == "" {
- port = "53"
- }
-
- return address{Original: input, Host: strings.ToLower(dns.Fqdn(host)), Port: port}, err
-}
-
-// replaceEnvVars replaces environment variables that appear in the token
-// and understands both the $UNIX and %WINDOWS% syntaxes.
-func replaceEnvVars(s string) string {
- s = replaceEnvReferences(s, "{%", "%}")
- s = replaceEnvReferences(s, "{$", "}")
- return s
-}
-
-// replaceEnvReferences performs the actual replacement of env variables
-// in s, given the placeholder start and placeholder end strings.
-func replaceEnvReferences(s, refStart, refEnd string) string {
- index := strings.Index(s, refStart)
- for index != -1 {
- endIndex := strings.Index(s, refEnd)
- if endIndex != -1 {
- ref := s[index : endIndex+len(refEnd)]
- s = strings.Replace(s, ref, os.Getenv(ref[len(refStart):len(ref)-len(refEnd)]), -1)
- } else {
- return s
- }
- index = strings.Index(s, refStart)
- }
- return s
-}
-
-type (
- // ServerBlock associates tokens with a list of addresses
- // and groups tokens by directive name.
- ServerBlock struct {
- Addresses []address
- Tokens map[string][]token
- }
-
- address struct {
- Original, Host, Port string
- }
-)
-
-// HostList converts the list of addresses that are
-// associated with this server block into a slice of
-// strings, where each address is as it was originally
-// read from the input.
-func (sb ServerBlock) HostList() []string {
- sbHosts := make([]string, len(sb.Addresses))
- for j, addr := range sb.Addresses {
- sbHosts[j] = addr.Original
- }
- return sbHosts
-}
diff --git a/core/parse/parsing_test.go b/core/parse/parsing_test.go
deleted file mode 100644
index d06fc8b58..000000000
--- a/core/parse/parsing_test.go
+++ /dev/null
@@ -1,401 +0,0 @@
-package parse
-
-import (
- "os"
- "strings"
- "testing"
-)
-
-func TestStandardAddress(t *testing.T) {
- for i, test := range []struct {
- input string
- host, port string
- shouldErr bool
- }{
- {`localhost`, "localhost.", "53", false},
- {`localhost:1234`, "localhost.", "1234", false},
- {`localhost:`, "localhost.", "53", false},
- {`0.0.0.0`, "0.0.0.0.", "53", false},
- {`127.0.0.1:1234`, "127.0.0.1.", "1234", false},
- {`:1234`, ".", "1234", false},
- {`[::1]`, "::1.", "53", false},
- {`[::1]:1234`, "::1.", "1234", false},
- {`:`, ".", "53", false},
- {`localhost:http`, "localhost.", "http", false},
- {`localhost:https`, "localhost.", "https", false},
- {``, ".", "53", false},
- {`::1`, "::1.", "53", true},
- {`localhost::`, "localhost::.", "53", true},
- {`#$%@`, "#$%@.", "53", true},
- } {
- actual, err := standardAddress(test.input)
-
- if err != nil && !test.shouldErr {
- t.Errorf("Test %d (%s): Expected no error, but had error: %v", i, test.input, err)
- }
- if err == nil && test.shouldErr {
- t.Errorf("Test %d (%s): Expected error, but had none", i, test.input)
- }
-
- if actual.Host != test.host {
- t.Errorf("Test %d (%s): Expected host '%s', got '%s'", i, test.input, test.host, actual.Host)
- }
- if actual.Port != test.port {
- t.Errorf("Test %d (%s): Expected port '%s', got '%s'", i, test.input, test.port, actual.Port)
- }
- }
-}
-
-func TestParseOneAndImport(t *testing.T) {
- setupParseTests()
-
- testParseOne := func(input string) (ServerBlock, error) {
- p := testParser(input)
- p.Next() // parseOne doesn't call Next() to start, so we must
- err := p.parseOne()
- return p.block, err
- }
-
- for i, test := range []struct {
- input string
- shouldErr bool
- addresses []address
- tokens map[string]int // map of directive name to number of tokens expected
- }{
- {`localhost`, false, []address{
- {"localhost", "localhost.", "53"},
- }, map[string]int{}},
-
- {`localhost
- dir1`, false, []address{
- {"localhost", "localhost.", "53"},
- }, map[string]int{
- "dir1": 1,
- }},
-
- {`localhost:1234
- dir1 foo bar`, false, []address{
- {"localhost:1234", "localhost.", "1234"},
- }, map[string]int{
- "dir1": 3,
- }},
-
- {`localhost {
- dir1
- }`, false, []address{
- {"localhost", "localhost.", "53"},
- }, map[string]int{
- "dir1": 1,
- }},
-
- {`localhost:1234 {
- dir1 foo bar
- dir2
- }`, false, []address{
- {"localhost:1234", "localhost.", "1234"},
- }, map[string]int{
- "dir1": 3,
- "dir2": 1,
- }},
-
- {`host1:80, host2.com
- dir1 foo bar
- dir2 baz`, false, []address{
- {"host1:80", "host1.", "80"},
- {"host2.com", "host2.com.", "53"},
- }, map[string]int{
- "dir1": 3,
- "dir2": 2,
- }},
-
- {`127.0.0.1
- dir1 {
- bar baz
- }
- dir2 {
- foo bar
- }`, false, []address{
- {"127.0.0.1", "127.0.0.1.", "53"},
- }, map[string]int{
- "dir1": 5,
- "dir2": 5,
- }},
-
- {`127.0.0.1
- unknown_directive`, true, []address{
- {"127.0.0.1", "127.0.0.1.", "53"},
- }, map[string]int{}},
-
- {`localhost
- dir1 {
- foo`, true, []address{
- {"localhost", "localhost.", "53"},
- }, map[string]int{
- "dir1": 3,
- }},
-
- {`localhost
- dir1 {
- }`, false, []address{
- {"localhost", "localhost.", "53"},
- }, map[string]int{
- "dir1": 3,
- }},
-
- {`localhost
- dir1 {
- } }`, true, []address{
- {"localhost", "localhost.", "53"},
- }, map[string]int{
- "dir1": 3,
- }},
-
- {`localhost
- dir1 {
- nested {
- foo
- }
- }
- dir2 foo bar`, false, []address{
- {"localhost", "localhost.", "53"},
- }, map[string]int{
- "dir1": 7,
- "dir2": 3,
- }},
-
- {``, false, []address{}, map[string]int{}},
-
- {`localhost
- dir1 arg1
- import import_test1.txt`, false, []address{
- {"localhost", "localhost.", "53"},
- }, map[string]int{
- "dir1": 2,
- "dir2": 3,
- "dir3": 1,
- }},
-
- {`import import_test2.txt`, false, []address{
- {"host1", "host1.", "53"},
- }, map[string]int{
- "dir1": 1,
- "dir2": 2,
- }},
-
- {`import import_test1.txt import_test2.txt`, true, []address{}, map[string]int{}},
-
- {`import not_found.txt`, true, []address{}, map[string]int{}},
-
- {`""`, false, []address{}, map[string]int{}},
-
- {``, false, []address{}, map[string]int{}},
- } {
- result, err := testParseOne(test.input)
-
- if test.shouldErr && err == nil {
- t.Errorf("Test %d: Expected an error, but didn't get one", i)
- }
- if !test.shouldErr && err != nil {
- t.Errorf("Test %d: Expected no error, but got: %v", i, err)
- }
-
- if len(result.Addresses) != len(test.addresses) {
- t.Errorf("Test %d: Expected %d addresses, got %d",
- i, len(test.addresses), len(result.Addresses))
- continue
- }
- for j, addr := range result.Addresses {
- if addr.Host != test.addresses[j].Host {
- t.Errorf("Test %d, address %d: Expected host to be '%s', but was '%s'",
- i, j, test.addresses[j].Host, addr.Host)
- }
- if addr.Port != test.addresses[j].Port {
- t.Errorf("Test %d, address %d: Expected port to be '%s', but was '%s'",
- i, j, test.addresses[j].Port, addr.Port)
- }
- }
-
- if len(result.Tokens) != len(test.tokens) {
- t.Errorf("Test %d: Expected %d directives, had %d",
- i, len(test.tokens), len(result.Tokens))
- continue
- }
- for directive, tokens := range result.Tokens {
- if len(tokens) != test.tokens[directive] {
- t.Errorf("Test %d, directive '%s': Expected %d tokens, counted %d",
- i, directive, test.tokens[directive], len(tokens))
- continue
- }
- }
- }
-}
-
-func TestParseAll(t *testing.T) {
- setupParseTests()
-
- for i, test := range []struct {
- input string
- shouldErr bool
- addresses [][]address // addresses per server block, in order
- }{
- {`localhost`, false, [][]address{
- {{"localhost", "localhost.", "53"}},
- }},
-
- {`localhost:1234`, false, [][]address{
- {{"localhost:1234", "localhost.", "1234"}},
- }},
-
- {`localhost:1234 {
- }
- localhost:2015 {
- }`, false, [][]address{
- {{"localhost:1234", "localhost.", "1234"}},
- {{"localhost:2015", "localhost.", "2015"}},
- }},
-
- {`localhost:1234, host2`, false, [][]address{
- {{"localhost:1234", "localhost.", "1234"}, {"host2", "host2.", "53"}},
- }},
-
- {`localhost:1234, http://host2,`, true, [][]address{}},
-
- {`import import_glob*.txt`, false, [][]address{
- {{"glob0.host0", "glob0.host0.", "53"}},
- {{"glob0.host1", "glob0.host1.", "53"}},
- {{"glob1.host0", "glob1.host0.", "53"}},
- {{"glob2.host0", "glob2.host0.", "53"}},
- }},
- } {
- p := testParser(test.input)
- blocks, err := p.parseAll()
-
- if test.shouldErr && err == nil {
- t.Errorf("Test %d: Expected an error, but didn't get one", i)
- }
- if !test.shouldErr && err != nil {
- t.Errorf("Test %d: Expected no error, but got: %v", i, err)
- }
-
- if len(blocks) != len(test.addresses) {
- t.Errorf("Test %d: Expected %d server blocks, got %d",
- i, len(test.addresses), len(blocks))
- continue
- }
- for j, block := range blocks {
- if len(block.Addresses) != len(test.addresses[j]) {
- t.Errorf("Test %d: Expected %d addresses in block %d, got %d",
- i, len(test.addresses[j]), j, len(block.Addresses))
- continue
- }
- for k, addr := range block.Addresses {
- if addr.Host != test.addresses[j][k].Host {
- t.Errorf("Test %d, block %d, address %d: Expected host to be '%s', but was '%s'",
- i, j, k, test.addresses[j][k].Host, addr.Host)
- }
- if addr.Port != test.addresses[j][k].Port {
- t.Errorf("Test %d, block %d, address %d: Expected port to be '%s', but was '%s'",
- i, j, k, test.addresses[j][k].Port, addr.Port)
- }
- }
- }
- }
-}
-
-func TestEnvironmentReplacement(t *testing.T) {
- setupParseTests()
-
- os.Setenv("PORT", "8080")
- os.Setenv("ADDRESS", "servername.com")
- os.Setenv("FOOBAR", "foobar")
-
- // basic test; unix-style env vars
- p := testParser(`{$ADDRESS}`)
- blocks, _ := p.parseAll()
- if actual, expected := blocks[0].Addresses[0].Host, "servername.com."; expected != actual {
- t.Errorf("Expected host to be '%s' but was '%s'", expected, actual)
- }
-
- // multiple vars per token
- p = testParser(`{$ADDRESS}:{$PORT}`)
- blocks, _ = p.parseAll()
- if actual, expected := blocks[0].Addresses[0].Host, "servername.com."; expected != actual {
- t.Errorf("Expected host to be '%s' but was '%s'", expected, actual)
- }
- if actual, expected := blocks[0].Addresses[0].Port, "8080"; expected != actual {
- t.Errorf("Expected port to be '%s' but was '%s'", expected, actual)
- }
-
- // windows-style var and unix style in same token
- p = testParser(`{%ADDRESS%}:{$PORT}`)
- blocks, _ = p.parseAll()
- if actual, expected := blocks[0].Addresses[0].Host, "servername.com."; expected != actual {
- t.Errorf("Expected host to be '%s' but was '%s'", expected, actual)
- }
- if actual, expected := blocks[0].Addresses[0].Port, "8080"; expected != actual {
- t.Errorf("Expected port to be '%s' but was '%s'", expected, actual)
- }
-
- // reverse order
- p = testParser(`{$ADDRESS}:{%PORT%}`)
- blocks, _ = p.parseAll()
- if actual, expected := blocks[0].Addresses[0].Host, "servername.com."; expected != actual {
- t.Errorf("Expected host to be '%s' but was '%s'", expected, actual)
- }
- if actual, expected := blocks[0].Addresses[0].Port, "8080"; expected != actual {
- t.Errorf("Expected port to be '%s' but was '%s'", expected, actual)
- }
-
- // env var in server block body as argument
- p = testParser(":{%PORT%}\ndir1 {$FOOBAR}")
- blocks, _ = p.parseAll()
- if actual, expected := blocks[0].Addresses[0].Port, "8080"; expected != actual {
- t.Errorf("Expected port to be '%s' but was '%s'", expected, actual)
- }
- if actual, expected := blocks[0].Tokens["dir1"][1].text, "foobar"; expected != actual {
- t.Errorf("Expected argument to be '%s' but was '%s'", expected, actual)
- }
-
- // combined windows env vars in argument
- p = testParser(":{%PORT%}\ndir1 {%ADDRESS%}/{%FOOBAR%}")
- blocks, _ = p.parseAll()
- if actual, expected := blocks[0].Tokens["dir1"][1].text, "servername.com/foobar"; expected != actual {
- t.Errorf("Expected argument to be '%s' but was '%s'", expected, actual)
- }
-
- // malformed env var (windows)
- p = testParser(":1234\ndir1 {%ADDRESS}")
- blocks, _ = p.parseAll()
- if actual, expected := blocks[0].Tokens["dir1"][1].text, "{%ADDRESS}"; expected != actual {
- t.Errorf("Expected host to be '%s' but was '%s'", expected, actual)
- }
-
- // malformed (non-existent) env var (unix)
- p = testParser(`:{$PORT$}`)
- blocks, _ = p.parseAll()
- if actual, expected := blocks[0].Addresses[0].Port, "53"; expected != actual {
- t.Errorf("Expected port to be '%s' but was '%s'", expected, actual)
- }
-
- // in quoted field
- p = testParser(":1234\ndir1 \"Test {$FOOBAR} test\"")
- blocks, _ = p.parseAll()
- if actual, expected := blocks[0].Tokens["dir1"][1].text, "Test foobar test"; expected != actual {
- t.Errorf("Expected argument to be '%s' but was '%s'", expected, actual)
- }
-}
-
-func setupParseTests() {
- // Set up some bogus directives for testing
- ValidDirectives = map[string]struct{}{
- "dir1": {},
- "dir2": {},
- "dir3": {},
- }
-}
-
-func testParser(input string) parser {
- buf := strings.NewReader(input)
- p := parser{Dispenser: NewDispenser("Test", buf), checkDirectives: true}
- return p
-}
diff --git a/core/restart.go b/core/restart.go
deleted file mode 100644
index aa77e152d..000000000
--- a/core/restart.go
+++ /dev/null
@@ -1,177 +0,0 @@
-// +build !windows
-
-package core
-
-import (
- "bytes"
- "encoding/gob"
- "errors"
- "io/ioutil"
- "log"
- "net"
- "os"
- "os/exec"
- "path"
- "sync/atomic"
-
- "github.com/miekg/coredns/core/https"
-)
-
-func init() {
- gob.Register(CorefileInput{})
-}
-
-// Restart restarts the entire application; gracefully with zero
-// downtime if on a POSIX-compatible system, or forcefully if on
-// Windows but with imperceptibly-short downtime.
-//
-// The restarted application will use newCorefile as its input
-// configuration. If newCorefile is nil, the current (existing)
-// Corefile configuration will be used.
-//
-// Note: The process must exist in the same place on the disk in
-// order for this to work. Thus, multiple graceful restarts don't
-// work if executing with `go run`, since the binary is cleaned up
-// when `go run` sees the initial parent process exit.
-func Restart(newCorefile Input) error {
- log.Println("[INFO] Restarting")
-
- if newCorefile == nil {
- corefileMu.Lock()
- newCorefile = corefile
- corefileMu.Unlock()
- }
-
- // Get certificates for any new hosts in the new Corefile without causing downtime
- err := getCertsForNewCorefile(newCorefile)
- if err != nil {
- return errors.New("TLS preload: " + err.Error())
- }
-
- if len(os.Args) == 0 { // this should never happen, but...
- os.Args = []string{""}
- }
-
- // Tell the child that it's a restart
- os.Setenv("COREDNS_RESTART", "true")
-
- // Prepare our payload to the child process
- crfileGob := corefileGob{
- ListenerFds: make(map[string]uintptr),
- Corefile: newCorefile,
- OnDemandTLSCertsIssued: atomic.LoadInt32(https.OnDemandIssuedCount),
- }
-
- // Prepare a pipe to the fork's stdin so it can get the Corefile
- rpipe, wpipe, err := os.Pipe()
- if err != nil {
- return err
- }
-
- // Prepare a pipe that the child process will use to communicate
- // its success with us by sending > 0 bytes
- sigrpipe, sigwpipe, err := os.Pipe()
- if err != nil {
- return err
- }
-
- // Pass along relevant file descriptors to child process; ordering
- // is very important since we rely on these being in certain positions.
- extraFiles := []*os.File{sigwpipe} // fd 3
-
- // Add file descriptors of all the sockets
- serversMu.Lock()
- j := 0
- for _, s := range servers {
- extraFiles = append(extraFiles, s.ListenerFd())
- extraFiles = append(extraFiles, s.PacketConnFd())
- // So this will be 0 1 2 3 TCP UDP TCP UDP ... etc.
- crfileGob.ListenerFds["tcp"+s.Addr] = uintptr(4 + j) // 4 fds come before any of the listeners
- crfileGob.ListenerFds["udp"+s.Addr] = uintptr(4 + j + 1) // add udp after that
- j += 2
- }
- serversMu.Unlock()
-
- // Set up the command
- cmd := exec.Command(os.Args[0], os.Args[1:]...)
- cmd.Stdin = rpipe // fd 0
- cmd.Stdout = os.Stdout // fd 1
- cmd.Stderr = os.Stderr // fd 2
- cmd.ExtraFiles = extraFiles
-
- // Spawn the child process
- err = cmd.Start()
- if err != nil {
- return err
- }
-
- // Immediately close our dup'ed fds and the write end of our signal pipe
- for _, f := range extraFiles {
- f.Close()
- }
-
- // Feed Corefile to the child
- err = gob.NewEncoder(wpipe).Encode(crfileGob)
- if err != nil {
- return err
- }
- wpipe.Close()
-
- // Run all shutdown functions for the middleware, if child start fails, restart them all...
- executeShutdownCallbacks("SIGUSR1")
-
- // Determine whether child startup succeeded
- answer, readErr := ioutil.ReadAll(sigrpipe)
- if answer == nil || len(answer) == 0 {
- cmdErr := cmd.Wait() // get exit status
- log.Printf("[ERROR] Restart: child failed to initialize (%v) - changes not applied", cmdErr)
- if readErr != nil {
- log.Printf("[ERROR] Restart: additionally, error communicating with child process: %v", readErr)
- }
- // re-call all startup functions.
- // TODO(miek): this needs to be tested, somehow.
- executeStartupCallbacks("SIGUSR1")
- return errIncompleteRestart
- }
-
- // Looks like child is successful; we can exit gracefully.
- return Stop()
-}
-
-func getCertsForNewCorefile(newCorefile Input) error {
- // parse the new corefile only up to (and including) TLS
- // so we can know what we need to get certs for.
- configs, _, _, err := loadConfigsUpToIncludingTLS(path.Base(newCorefile.Path()), bytes.NewReader(newCorefile.Body()))
- if err != nil {
- return errors.New("loading Corefile: " + err.Error())
- }
-
- // first mark the configs that are qualified for managed TLS
- https.MarkQualified(configs)
-
- // since we group by bind address to obtain certs, we must call
- // EnableTLS to make sure the port is set properly first
- // (can ignore error since we aren't actually using the certs)
- https.EnableTLS(configs, false)
-
- // find out if we can let the acme package start its own challenge listener
- // on port 80
- var proxyACME bool
- serversMu.Lock()
- for _, s := range servers {
- _, port, _ := net.SplitHostPort(s.Addr)
- if port == "80" {
- proxyACME = true
- break
- }
- }
- serversMu.Unlock()
-
- // place certs on the disk
- err = https.ObtainCerts(configs, false, proxyACME)
- if err != nil {
- return errors.New("obtaining certs: " + err.Error())
- }
-
- return nil
-}
diff --git a/core/restart_windows.go b/core/restart_windows.go
deleted file mode 100644
index 8a0805a19..000000000
--- a/core/restart_windows.go
+++ /dev/null
@@ -1,31 +0,0 @@
-package core
-
-import "log"
-
-// Restart restarts CoreDNS forcefully using newCorefile,
-// or, if nil, the current/existing Corefile is reused.
-func Restart(newCorefile Input) error {
- log.Println("[INFO] Restarting")
-
- if newCorefile == nil {
- corefileMu.Lock()
- newCorefile = corefile
- corefileMu.Unlock()
- }
-
- wg.Add(1) // barrier so Wait() doesn't unblock
-
- err := Stop()
- if err != nil {
- return err
- }
-
- err = Start(newCorefile)
- if err != nil {
- return err
- }
-
- wg.Done() // take down our barrier
-
- return nil
-}
diff --git a/core/setup/bindhost.go b/core/setup/bindhost.go
deleted file mode 100644
index a3c07e5eb..000000000
--- a/core/setup/bindhost.go
+++ /dev/null
@@ -1,13 +0,0 @@
-package setup
-
-import "github.com/miekg/coredns/middleware"
-
-// BindHost sets the host to bind the listener to.
-func BindHost(c *Controller) (middleware.Middleware, error) {
- for c.Next() {
- if !c.Args(&c.BindHost) {
- return nil, c.ArgErr()
- }
- }
- return nil, nil
-}
diff --git a/core/setup/chaos.go b/core/setup/chaos.go
deleted file mode 100644
index 13103adf4..000000000
--- a/core/setup/chaos.go
+++ /dev/null
@@ -1,45 +0,0 @@
-package setup
-
-import (
- "github.com/miekg/coredns/middleware"
- "github.com/miekg/coredns/middleware/chaos"
-)
-
-// Chaos configures a new Chaos middleware instance.
-func Chaos(c *Controller) (middleware.Middleware, error) {
- version, authors, err := chaosParse(c)
- if err != nil {
- return nil, err
- }
-
- return func(next middleware.Handler) middleware.Handler {
- return chaos.Chaos{
- Next: next,
- Version: version,
- Authors: authors,
- }
- }, nil
-}
-
-func chaosParse(c *Controller) (string, map[string]bool, error) {
- version := ""
- authors := make(map[string]bool)
-
- for c.Next() {
- args := c.RemainingArgs()
- if len(args) == 0 {
- return defaultVersion, nil, nil
- }
- if len(args) == 1 {
- return args[0], nil, nil
- }
- version = args[0]
- for _, a := range args[1:] {
- authors[a] = true
- }
- return version, authors, nil
- }
- return version, authors, nil
-}
-
-const defaultVersion = "CoreDNS"
diff --git a/core/setup/controller.go b/core/setup/controller.go
deleted file mode 100644
index 7f8da6721..000000000
--- a/core/setup/controller.go
+++ /dev/null
@@ -1,85 +0,0 @@
-package setup
-
-import (
- "fmt"
- "strings"
-
- "golang.org/x/net/context"
-
- "github.com/miekg/coredns/core/parse"
- "github.com/miekg/coredns/middleware"
- "github.com/miekg/coredns/server"
- "github.com/miekg/dns"
-)
-
-// Controller is given to the setup function of middlewares which
-// gives them access to be able to read tokens and set config. Each
-// virtualhost gets their own server config and dispenser.
-type Controller struct {
- *server.Config
- parse.Dispenser
-
- // OncePerServerBlock is a function that executes f
- // exactly once per server block, no matter how many
- // hosts are associated with it. If it is the first
- // time, the function f is executed immediately
- // (not deferred) and may return an error which is
- // returned by OncePerServerBlock.
- OncePerServerBlock func(f func() error) error
-
- // ServerBlockIndex is the 0-based index of the
- // server block as it appeared in the input.
- ServerBlockIndex int
-
- // ServerBlockHostIndex is the 0-based index of this
- // host as it appeared in the input at the head of the
- // server block.
- ServerBlockHostIndex int
-
- // ServerBlockHosts is a list of hosts that are
- // associated with this server block. All these
- // hosts, consequently, share the same tokens.
- ServerBlockHosts []string
-
- // ServerBlockStorage is used by a directive's
- // setup function to persist state between all
- // the hosts on a server block.
- ServerBlockStorage interface{}
-}
-
-// NewTestController creates a new *Controller for
-// the input specified, with a filename of "Testfile".
-// The Config is bare, consisting only of a Root of cwd.
-//
-// Used primarily for testing but needs to be exported so
-// add-ons can use this as a convenience. Does not initialize
-// the server-block-related fields.
-func NewTestController(input string) *Controller {
- return &Controller{
- Config: &server.Config{
- Root: ".",
- },
- Dispenser: parse.NewDispenser("Testfile", strings.NewReader(input)),
- OncePerServerBlock: func(f func() error) error {
- return f()
- },
- }
-}
-
-// EmptyNext is a no-op function that can be passed into
-// middleware.Middleware functions so that the assignment
-// to the Next field of the Handler can be tested.
-//
-// Used primarily for testing but needs to be exported so
-// add-ons can use this as a convenience.
-var EmptyNext = middleware.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
- return 0, nil
-})
-
-// SameNext does a pointer comparison between next1 and next2.
-//
-// Used primarily for testing but needs to be exported so
-// add-ons can use this as a convenience.
-func SameNext(next1, next2 middleware.Handler) bool {
- return fmt.Sprintf("%v", next1) == fmt.Sprintf("%v", next2)
-}
diff --git a/core/setup/health.go b/core/setup/health.go
deleted file mode 100644
index 542cb3260..000000000
--- a/core/setup/health.go
+++ /dev/null
@@ -1,34 +0,0 @@
-package setup
-
-import (
- "github.com/miekg/coredns/middleware"
- "github.com/miekg/coredns/middleware/health"
-)
-
-func Health(c *Controller) (middleware.Middleware, error) {
- addr, err := parseHealth(c)
- if err != nil {
- return nil, err
- }
-
- h := &health.Health{Addr: addr}
- c.Startup = append(c.Startup, h.Start)
- c.Shutdown = append(c.Shutdown, h.Shutdown)
- return nil, nil
-}
-
-func parseHealth(c *Controller) (string, error) {
- addr := ""
- for c.Next() {
- args := c.RemainingArgs()
-
- switch len(args) {
- case 0:
- case 1:
- addr = args[0]
- default:
- return "", c.ArgErr()
- }
- }
- return addr, nil
-}
diff --git a/core/setup/loadbalance.go b/core/setup/loadbalance.go
deleted file mode 100644
index 4b132489b..000000000
--- a/core/setup/loadbalance.go
+++ /dev/null
@@ -1,16 +0,0 @@
-package setup
-
-import (
- "github.com/miekg/coredns/middleware"
- "github.com/miekg/coredns/middleware/loadbalance"
-)
-
-// Loadbalance sets up the root file path of the server.
-func Loadbalance(c *Controller) (middleware.Middleware, error) {
- for c.Next() {
- // TODO(miek): block and option parsing
- }
- return func(next middleware.Handler) middleware.Handler {
- return loadbalance.RoundRobin{Next: next}
- }, nil
-}
diff --git a/core/setup/metrics.go b/core/setup/metrics.go
deleted file mode 100644
index e88d93c86..000000000
--- a/core/setup/metrics.go
+++ /dev/null
@@ -1,72 +0,0 @@
-package setup
-
-import (
- "sync"
-
- "github.com/miekg/coredns/middleware"
- "github.com/miekg/coredns/middleware/metrics"
-)
-
-const addr = "localhost:9153"
-
-var metricsOnce sync.Once
-
-func Prometheus(c *Controller) (middleware.Middleware, error) {
- m, err := parsePrometheus(c)
- if err != nil {
- return nil, err
- }
-
- metricsOnce.Do(func() {
- c.Startup = append(c.Startup, m.Start)
- c.Shutdown = append(c.Shutdown, m.Shutdown)
- })
-
- return func(next middleware.Handler) middleware.Handler {
- m.Next = next
- return m
- }, nil
-}
-
-func parsePrometheus(c *Controller) (metrics.Metrics, error) {
- var (
- met metrics.Metrics
- err error
- )
-
- for c.Next() {
- if len(met.ZoneNames) > 0 {
- return metrics.Metrics{}, c.Err("metrics: can only have one metrics module per server")
- }
- met = metrics.Metrics{ZoneNames: c.ServerBlockHosts}
- for i, _ := range met.ZoneNames {
- met.ZoneNames[i] = middleware.Host(met.ZoneNames[i]).Normalize()
- }
- args := c.RemainingArgs()
-
- switch len(args) {
- case 0:
- case 1:
- met.Addr = args[0]
- default:
- return metrics.Metrics{}, c.ArgErr()
- }
- for c.NextBlock() {
- switch c.Val() {
- case "address":
- args = c.RemainingArgs()
- if len(args) != 1 {
- return metrics.Metrics{}, c.ArgErr()
- }
- met.Addr = args[0]
- default:
- return metrics.Metrics{}, c.Errf("metrics: unknown item: %s", c.Val())
- }
-
- }
- }
- if met.Addr == "" {
- met.Addr = addr
- }
- return met, err
-}
diff --git a/core/setup/pprof.go b/core/setup/pprof.go
deleted file mode 100644
index 125d2a9ef..000000000
--- a/core/setup/pprof.go
+++ /dev/null
@@ -1,33 +0,0 @@
-package setup
-
-import (
- "sync"
-
- "github.com/miekg/coredns/middleware"
- "github.com/miekg/coredns/middleware/pprof"
-)
-
-var pprofOnce sync.Once
-
-// PProf returns a new instance of a pprof handler. It accepts no arguments or options.
-func PProf(c *Controller) (middleware.Middleware, error) {
- found := false
- for c.Next() {
- if found {
- return nil, c.Err("pprof can only be specified once")
- }
- if len(c.RemainingArgs()) != 0 {
- return nil, c.ArgErr()
- }
- if c.NextBlock() {
- return nil, c.ArgErr()
- }
- found = true
- }
- handler := &pprof.Handler{}
- pprofOnce.Do(func() {
- c.Startup = append(c.Startup, handler.Start)
- c.Shutdown = append(c.Shutdown, handler.Shutdown)
- })
- return nil, nil
-}
diff --git a/core/setup/proxy.go b/core/setup/proxy.go
deleted file mode 100644
index 6753d07ad..000000000
--- a/core/setup/proxy.go
+++ /dev/null
@@ -1,17 +0,0 @@
-package setup
-
-import (
- "github.com/miekg/coredns/middleware"
- "github.com/miekg/coredns/middleware/proxy"
-)
-
-// Proxy configures a new Proxy middleware instance.
-func Proxy(c *Controller) (middleware.Middleware, error) {
- upstreams, err := proxy.NewStaticUpstreams(c.Dispenser)
- if err != nil {
- return nil, err
- }
- return func(next middleware.Handler) middleware.Handler {
- return proxy.Proxy{Next: next, Client: proxy.Clients(), Upstreams: upstreams}
- }, nil
-}
diff --git a/core/setup/rewrite_test.go b/core/setup/rewrite_test.go
deleted file mode 100644
index 5345c4bf6..000000000
--- a/core/setup/rewrite_test.go
+++ /dev/null
@@ -1,234 +0,0 @@
-package setup
-
-/*
-func TestRewrite(t *testing.T) {
- c := NewTestController(`rewrite /from /to`)
-
- mid, err := Rewrite(c)
- if err != nil {
- t.Errorf("Expected no errors, but got: %v", err)
- }
- if mid == nil {
- t.Fatal("Expected middleware, was nil instead")
- }
-
- handler := mid(EmptyNext)
- myHandler, ok := handler.(rewrite.Rewrite)
- if !ok {
- t.Fatalf("Expected handler to be type Rewrite, got: %#v", handler)
- }
-
- if !SameNext(myHandler.Next, EmptyNext) {
- t.Error("'Next' field of handler was not set properly")
- }
-
- if len(myHandler.Rules) != 1 {
- t.Errorf("Expected handler to have %d rule, has %d instead", 1, len(myHandler.Rules))
- }
-}
-
-func TestRewriteParse(t *testing.T) {
- simpleTests := []struct {
- input string
- shouldErr bool
- expected []rewrite.Rule
- }{
- {`rewrite /from /to`, false, []rewrite.Rule{
- rewrite.SimpleRule{From: "/from", To: "/to"},
- }},
- {`rewrite /from /to
- rewrite a b`, false, []rewrite.Rule{
- rewrite.SimpleRule{From: "/from", To: "/to"},
- rewrite.SimpleRule{From: "a", To: "b"},
- }},
- {`rewrite a`, true, []rewrite.Rule{}},
- {`rewrite`, true, []rewrite.Rule{}},
- {`rewrite a b c`, false, []rewrite.Rule{
- rewrite.SimpleRule{From: "a", To: "b c"},
- }},
- }
-
- for i, test := range simpleTests {
- c := NewTestController(test.input)
- actual, err := rewriteParse(c)
-
- if err == nil && test.shouldErr {
- t.Errorf("Test %d didn't error, but it should have", i)
- } else if err != nil && !test.shouldErr {
- t.Errorf("Test %d errored, but it shouldn't have; got '%v'", i, err)
- } else if err != nil && test.shouldErr {
- continue
- }
-
- if len(actual) != len(test.expected) {
- t.Fatalf("Test %d expected %d rules, but got %d",
- i, len(test.expected), len(actual))
- }
-
- for j, e := range test.expected {
- actualRule := actual[j].(rewrite.SimpleRule)
- expectedRule := e.(rewrite.SimpleRule)
-
- if actualRule.From != expectedRule.From {
- t.Errorf("Test %d, rule %d: Expected From=%s, got %s",
- i, j, expectedRule.From, actualRule.From)
- }
-
- if actualRule.To != expectedRule.To {
- t.Errorf("Test %d, rule %d: Expected To=%s, got %s",
- i, j, expectedRule.To, actualRule.To)
- }
- }
- }
-
- regexpTests := []struct {
- input string
- shouldErr bool
- expected []rewrite.Rule
- }{
- {`rewrite {
- r .*
- to /to /index.php?
- }`, false, []rewrite.Rule{
- &rewrite.ComplexRule{Base: "/", To: "/to /index.php?", Regexp: regexp.MustCompile(".*")},
- }},
- {`rewrite {
- regexp .*
- to /to
- ext / html txt
- }`, false, []rewrite.Rule{
- &rewrite.ComplexRule{Base: "/", To: "/to", Exts: []string{"/", "html", "txt"}, Regexp: regexp.MustCompile(".*")},
- }},
- {`rewrite /path {
- r rr
- to /dest
- }
- rewrite / {
- regexp [a-z]+
- to /to /to2
- }
- `, false, []rewrite.Rule{
- &rewrite.ComplexRule{Base: "/path", To: "/dest", Regexp: regexp.MustCompile("rr")},
- &rewrite.ComplexRule{Base: "/", To: "/to /to2", Regexp: regexp.MustCompile("[a-z]+")},
- }},
- {`rewrite {
- r .*
- }`, true, []rewrite.Rule{
- &rewrite.ComplexRule{},
- }},
- {`rewrite {
-
- }`, true, []rewrite.Rule{
- &rewrite.ComplexRule{},
- }},
- {`rewrite /`, true, []rewrite.Rule{
- &rewrite.ComplexRule{},
- }},
- {`rewrite {
- to /to
- if {path} is a
- }`, false, []rewrite.Rule{
- &rewrite.ComplexRule{Base: "/", To: "/to", Ifs: []rewrite.If{{A: "{path}", Operator: "is", B: "a"}}},
- }},
- {`rewrite {
- status 500
- }`, true, []rewrite.Rule{
- &rewrite.ComplexRule{},
- }},
- {`rewrite {
- status 400
- }`, false, []rewrite.Rule{
- &rewrite.ComplexRule{Base: "/", Status: 400},
- }},
- {`rewrite {
- to /to
- status 400
- }`, false, []rewrite.Rule{
- &rewrite.ComplexRule{Base: "/", To: "/to", Status: 400},
- }},
- {`rewrite {
- status 399
- }`, true, []rewrite.Rule{
- &rewrite.ComplexRule{},
- }},
- {`rewrite {
- status 200
- }`, false, []rewrite.Rule{
- &rewrite.ComplexRule{Base: "/", Status: 200},
- }},
- {`rewrite {
- to /to
- status 200
- }`, false, []rewrite.Rule{
- &rewrite.ComplexRule{Base: "/", To: "/to", Status: 200},
- }},
- {`rewrite {
- status 199
- }`, true, []rewrite.Rule{
- &rewrite.ComplexRule{},
- }},
- {`rewrite {
- status 0
- }`, true, []rewrite.Rule{
- &rewrite.ComplexRule{},
- }},
- {`rewrite {
- to /to
- status 0
- }`, true, []rewrite.Rule{
- &rewrite.ComplexRule{},
- }},
- }
-
- for i, test := range regexpTests {
- c := NewTestController(test.input)
- actual, err := rewriteParse(c)
-
- if err == nil && test.shouldErr {
- t.Errorf("Test %d didn't error, but it should have", i)
- } else if err != nil && !test.shouldErr {
- t.Errorf("Test %d errored, but it shouldn't have; got '%v'", i, err)
- } else if err != nil && test.shouldErr {
- continue
- }
-
- if len(actual) != len(test.expected) {
- t.Fatalf("Test %d expected %d rules, but got %d",
- i, len(test.expected), len(actual))
- }
-
- for j, e := range test.expected {
- actualRule := actual[j].(*rewrite.ComplexRule)
- expectedRule := e.(*rewrite.ComplexRule)
-
- if actualRule.Base != expectedRule.Base {
- t.Errorf("Test %d, rule %d: Expected Base=%s, got %s",
- i, j, expectedRule.Base, actualRule.Base)
- }
-
- if actualRule.To != expectedRule.To {
- t.Errorf("Test %d, rule %d: Expected To=%s, got %s",
- i, j, expectedRule.To, actualRule.To)
- }
-
- if fmt.Sprint(actualRule.Exts) != fmt.Sprint(expectedRule.Exts) {
- t.Errorf("Test %d, rule %d: Expected Ext=%v, got %v",
- i, j, expectedRule.To, actualRule.To)
- }
-
- if actualRule.Regexp != nil {
- if actualRule.String() != expectedRule.String() {
- t.Errorf("Test %d, rule %d: Expected Pattern=%s, got %s",
- i, j, expectedRule.String(), actualRule.String())
- }
- }
-
- if fmt.Sprint(actualRule.Ifs) != fmt.Sprint(expectedRule.Ifs) {
- t.Errorf("Test %d, rule %d: Expected Pattern=%s, got %s",
- i, j, fmt.Sprint(expectedRule.Ifs), fmt.Sprint(actualRule.Ifs))
- }
-
- }
- }
-}
-*/
diff --git a/core/setup/roller.go b/core/setup/roller.go
deleted file mode 100644
index fd772cc47..000000000
--- a/core/setup/roller.go
+++ /dev/null
@@ -1,40 +0,0 @@
-package setup
-
-import (
- "strconv"
-
- "github.com/miekg/coredns/middleware"
-)
-
-func parseRoller(c *Controller) (*middleware.LogRoller, error) {
- var size, age, keep int
- // This is kind of a hack to support nested blocks:
- // As we are already in a block: either log or errors,
- // c.nesting > 0 but, as soon as c meets a }, it thinks
- // the block is over and return false for c.NextBlock.
- for c.NextBlock() {
- what := c.Val()
- if !c.NextArg() {
- return nil, c.ArgErr()
- }
- value := c.Val()
- var err error
- switch what {
- case "size":
- size, err = strconv.Atoi(value)
- case "age":
- age, err = strconv.Atoi(value)
- case "keep":
- keep, err = strconv.Atoi(value)
- }
- if err != nil {
- return nil, err
- }
- }
- return &middleware.LogRoller{
- MaxSize: size,
- MaxAge: age,
- MaxBackups: keep,
- LocalTime: true,
- }, nil
-}
diff --git a/core/setup/root.go b/core/setup/root.go
deleted file mode 100644
index 0fce5f170..000000000
--- a/core/setup/root.go
+++ /dev/null
@@ -1,32 +0,0 @@
-package setup
-
-import (
- "log"
- "os"
-
- "github.com/miekg/coredns/middleware"
-)
-
-// Root sets up the root file path of the server.
-func Root(c *Controller) (middleware.Middleware, error) {
- for c.Next() {
- if !c.NextArg() {
- return nil, c.ArgErr()
- }
- c.Root = c.Val()
- }
-
- // Check if root path exists
- _, err := os.Stat(c.Root)
- if err != nil {
- if os.IsNotExist(err) {
- // Allow this, because the folder might appear later.
- // But make sure the user knows!
- log.Printf("[WARNING] Root path does not exist: %s", c.Root)
- } else {
- return nil, c.Errf("Unable to access root path '%s': %v", c.Root, err)
- }
- }
-
- return nil, nil
-}
diff --git a/core/setup/root_test.go b/core/setup/root_test.go
deleted file mode 100644
index 8b38e6d04..000000000
--- a/core/setup/root_test.go
+++ /dev/null
@@ -1,108 +0,0 @@
-package setup
-
-import (
- "fmt"
- "io/ioutil"
- "os"
- "path/filepath"
- "strings"
- "testing"
-)
-
-func TestRoot(t *testing.T) {
-
- // Predefined error substrings
- parseErrContent := "Parse error:"
- unableToAccessErrContent := "Unable to access root path"
-
- existingDirPath, err := getTempDirPath()
- if err != nil {
- t.Fatalf("BeforeTest: Failed to find an existing directory for testing! Error was: %v", err)
- }
-
- nonExistingDir := filepath.Join(existingDirPath, "highly_unlikely_to_exist_dir")
-
- existingFile, err := ioutil.TempFile("", "root_test")
- if err != nil {
- t.Fatalf("BeforeTest: Failed to create temp file for testing! Error was: %v", err)
- }
- defer func() {
- existingFile.Close()
- os.Remove(existingFile.Name())
- }()
-
- inaccessiblePath := getInaccessiblePath(existingFile.Name())
-
- tests := []struct {
- input string
- shouldErr bool
- expectedRoot string // expected root, set to the controller. Empty for negative cases.
- expectedErrContent string // substring from the expected error. Empty for positive cases.
- }{
- // positive
- {
- fmt.Sprintf(`root %s`, nonExistingDir), false, nonExistingDir, "",
- },
- {
- fmt.Sprintf(`root %s`, existingDirPath), false, existingDirPath, "",
- },
- // negative
- {
- `root `, true, "", parseErrContent,
- },
- {
- fmt.Sprintf(`root %s`, inaccessiblePath), true, "", unableToAccessErrContent,
- },
- {
- fmt.Sprintf(`root {
- %s
- }`, existingDirPath), true, "", parseErrContent,
- },
- }
-
- for i, test := range tests {
- c := NewTestController(test.input)
- mid, err := Root(c)
-
- if test.shouldErr && err == nil {
- t.Errorf("Test %d: Expected error but found %s for input %s", i, err, test.input)
- }
-
- if err != nil {
- if !test.shouldErr {
- t.Errorf("Test %d: Expected no error but found one for input %s. Error was: %v", i, test.input, err)
- }
-
- if !strings.Contains(err.Error(), test.expectedErrContent) {
- t.Errorf("Test %d: Expected error to contain: %v, found error: %v, input: %s", i, test.expectedErrContent, err, test.input)
- }
- }
-
- // the Root method always returns a nil middleware
- if mid != nil {
- t.Errorf("Middware, returned from Root() was not nil: %v", mid)
- }
-
- // check c.Root only if we are in a positive test.
- if !test.shouldErr && test.expectedRoot != c.Root {
- t.Errorf("Root not correctly set for input %s. Expected: %s, actual: %s", test.input, test.expectedRoot, c.Root)
- }
- }
-}
-
-// getTempDirPath returnes the path to the system temp directory. If it does not exists - an error is returned.
-func getTempDirPath() (string, error) {
- tempDir := os.TempDir()
-
- _, err := os.Stat(tempDir)
- if err != nil {
- return "", err
- }
-
- return tempDir, nil
-}
-
-func getInaccessiblePath(file string) string {
- // null byte in filename is not allowed on Windows AND unix
- return filepath.Join("C:", "file\x00name")
-}
diff --git a/core/setup/startupshutdown.go b/core/setup/startupshutdown.go
deleted file mode 100644
index 1cf2c62e0..000000000
--- a/core/setup/startupshutdown.go
+++ /dev/null
@@ -1,64 +0,0 @@
-package setup
-
-import (
- "os"
- "os/exec"
- "strings"
-
- "github.com/miekg/coredns/middleware"
-)
-
-// Startup registers a startup callback to execute during server start.
-func Startup(c *Controller) (middleware.Middleware, error) {
- return nil, registerCallback(c, &c.FirstStartup)
-}
-
-// Shutdown registers a shutdown callback to execute during process exit.
-func Shutdown(c *Controller) (middleware.Middleware, error) {
- return nil, registerCallback(c, &c.Shutdown)
-}
-
-// registerCallback registers a callback function to execute by
-// using c to parse the line. It appends the callback function
-// to the list of callback functions passed in by reference.
-func registerCallback(c *Controller, list *[]func() error) error {
- var funcs []func() error
-
- for c.Next() {
- args := c.RemainingArgs()
- if len(args) == 0 {
- return c.ArgErr()
- }
-
- nonblock := false
- if len(args) > 1 && args[len(args)-1] == "&" {
- // Run command in background; non-blocking
- nonblock = true
- args = args[:len(args)-1]
- }
-
- command, args, err := middleware.SplitCommandAndArgs(strings.Join(args, " "))
- if err != nil {
- return c.Err(err.Error())
- }
-
- fn := func() error {
- cmd := exec.Command(command, args...)
- cmd.Stdin = os.Stdin
- cmd.Stdout = os.Stdout
- cmd.Stderr = os.Stderr
-
- if nonblock {
- return cmd.Start()
- }
- return cmd.Run()
- }
-
- funcs = append(funcs, fn)
- }
-
- return c.OncePerServerBlock(func() error {
- *list = append(*list, funcs...)
- return nil
- })
-}
diff --git a/core/setup/startupshutdown_test.go b/core/setup/startupshutdown_test.go
deleted file mode 100644
index 871a64214..000000000
--- a/core/setup/startupshutdown_test.go
+++ /dev/null
@@ -1,59 +0,0 @@
-package setup
-
-import (
- "os"
- "path/filepath"
- "strconv"
- "testing"
- "time"
-)
-
-// The Startup function's tests are symmetrical to Shutdown tests,
-// because the Startup and Shutdown functions share virtually the
-// same functionality
-func TestStartup(t *testing.T) {
- tempDirPath, err := getTempDirPath()
- if err != nil {
- t.Fatalf("BeforeTest: Failed to find an existing directory for testing! Error was: %v", err)
- }
-
- testDir := filepath.Join(tempDirPath, "temp_dir_for_testing_startupshutdown")
- defer func() {
- // clean up after non-blocking startup function quits
- time.Sleep(500 * time.Millisecond)
- os.RemoveAll(testDir)
- }()
- osSenitiveTestDir := filepath.FromSlash(testDir)
- os.RemoveAll(osSenitiveTestDir) // start with a clean slate
-
- tests := []struct {
- input string
- shouldExecutionErr bool
- shouldRemoveErr bool
- }{
- // test case #0 tests proper functionality blocking commands
- {"startup mkdir " + osSenitiveTestDir, false, false},
-
- // test case #1 tests proper functionality of non-blocking commands
- {"startup mkdir " + osSenitiveTestDir + " &", false, true},
-
- // test case #2 tests handling of non-existent commands
- {"startup " + strconv.Itoa(int(time.Now().UnixNano())), true, true},
- }
-
- for i, test := range tests {
- c := NewTestController(test.input)
- _, err = Startup(c)
- if err != nil {
- t.Errorf("Expected no errors, got: %v", err)
- }
- err = c.FirstStartup[0]()
- if err != nil && !test.shouldExecutionErr {
- t.Errorf("Test %d recieved an error of:\n%v", i, err)
- }
- err = os.Remove(osSenitiveTestDir)
- if err != nil && !test.shouldRemoveErr {
- t.Errorf("Test %d recieved an error of:\n%v", i, err)
- }
- }
-}
diff --git a/core/setup/testdata/blog/first_post.md b/core/setup/testdata/blog/first_post.md
deleted file mode 100644
index f26583b75..000000000
--- a/core/setup/testdata/blog/first_post.md
+++ /dev/null
@@ -1 +0,0 @@
-# Test h1
diff --git a/core/setup/testdata/header.html b/core/setup/testdata/header.html
deleted file mode 100644
index 9c96e0e37..000000000
--- a/core/setup/testdata/header.html
+++ /dev/null
@@ -1 +0,0 @@
-Header title
diff --git a/core/setup/testdata/tpl_with_include.html b/core/setup/testdata/tpl_with_include.html
deleted file mode 100644
index 95eeae0c8..000000000
--- a/core/setup/testdata/tpl_with_include.html
+++ /dev/null
@@ -1,10 +0,0 @@
-
-
-
-{{.Doc.title}}
-
-
-{{.Include "header.html"}}
-{{.Doc.body}}
-
-
diff --git a/core/sigtrap.go b/core/sigtrap.go
deleted file mode 100644
index f40dd971a..000000000
--- a/core/sigtrap.go
+++ /dev/null
@@ -1,93 +0,0 @@
-package core
-
-import (
- "log"
- "os"
- "os/signal"
- "sync"
-
- "github.com/miekg/coredns/server"
-)
-
-// TrapSignals create signal handlers for all applicable signals for this
-// system. If your Go program uses signals, this is a rather invasive
-// function; best to implement them yourself in that case. Signals are not
-// required for the caddy package to function properly, but this is a
-// convenient way to allow the user to control this package of your program.
-func TrapSignals() {
- trapSignalsCrossPlatform()
- trapSignalsPosix()
-}
-
-// trapSignalsCrossPlatform captures SIGINT, which triggers forceful
-// shutdown that executes shutdown callbacks first. A second interrupt
-// signal will exit the process immediately.
-func trapSignalsCrossPlatform() {
- go func() {
- shutdown := make(chan os.Signal, 1)
- signal.Notify(shutdown, os.Interrupt)
-
- for i := 0; true; i++ {
- <-shutdown
-
- if i > 0 {
- log.Println("[INFO] SIGINT: Force quit")
- if PidFile != "" {
- os.Remove(PidFile)
- }
- os.Exit(1)
- }
-
- log.Println("[INFO] SIGINT: Shutting down")
-
- if PidFile != "" {
- os.Remove(PidFile)
- }
-
- go os.Exit(executeShutdownCallbacks("SIGINT"))
- }
- }()
-}
-
-// executeShutdownCallbacks executes the shutdown callbacks as initiated
-// by signame. It logs any errors and returns the recommended exit status.
-// This function is idempotent; subsequent invocations always return 0.
-func executeShutdownCallbacks(signame string) (exitCode int) {
- shutdownCallbacksOnce.Do(func() {
- serversMu.Lock()
- errs := server.ShutdownCallbacks(servers)
- serversMu.Unlock()
-
- if len(errs) > 0 {
- for _, err := range errs {
- log.Printf("[ERROR] %s shutdown: %v", signame, err)
- }
- exitCode = 1
- }
- })
- return
-}
-
-// executeStartupCallbacks executes the startup callbacks as initiated
-// by signame. This is used when on restart when the child failed to start and
-// all middleware executed their shutdown functions
-func executeStartupCallbacks(signame string) (exitCode int) {
- startupCallbacksOnce.Do(func() {
- serversMu.Lock()
- errs := server.StartupCallbacks(servers)
- serversMu.Unlock()
-
- if len(errs) > 0 {
- for _, err := range errs {
- log.Printf("[ERROR] %s shutdown: %v", signame, err)
- }
- exitCode = 1
- }
- })
- return
-}
-
-var (
- shutdownCallbacksOnce sync.Once
- startupCallbacksOnce sync.Once
-)
diff --git a/core/sigtrap_posix.go b/core/sigtrap_posix.go
deleted file mode 100644
index d120dd52b..000000000
--- a/core/sigtrap_posix.go
+++ /dev/null
@@ -1,79 +0,0 @@
-// +build !windows
-
-package core
-
-import (
- "io/ioutil"
- "log"
- "os"
- "os/signal"
- "syscall"
-)
-
-// trapSignalsPosix captures POSIX-only signals.
-func trapSignalsPosix() {
- go func() {
- sigchan := make(chan os.Signal, 1)
- signal.Notify(sigchan, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT, syscall.SIGUSR1)
-
- for sig := range sigchan {
- switch sig {
- case syscall.SIGTERM:
- log.Println("[INFO] SIGTERM: Terminating process")
- if PidFile != "" {
- os.Remove(PidFile)
- }
- os.Exit(0)
-
- case syscall.SIGQUIT:
- log.Println("[INFO] SIGQUIT: Shutting down")
- exitCode := executeShutdownCallbacks("SIGQUIT")
- err := Stop()
- if err != nil {
- log.Printf("[ERROR] SIGQUIT stop: %v", err)
- exitCode = 1
- }
- if PidFile != "" {
- os.Remove(PidFile)
- }
- os.Exit(exitCode)
-
- case syscall.SIGHUP:
- log.Println("[INFO] SIGHUP: Hanging up")
- err := Stop()
- if err != nil {
- log.Printf("[ERROR] SIGHUP stop: %v", err)
- }
-
- case syscall.SIGUSR1:
- log.Println("[INFO] SIGUSR1: Reloading")
-
- var updatedCorefile Input
-
- corefileMu.Lock()
- if corefile == nil {
- // Hmm, did spawing process forget to close stdin? Anyhow, this is unusual.
- log.Println("[ERROR] SIGUSR1: no Corefile to reload (was stdin left open?)")
- corefileMu.Unlock()
- continue
- }
- if corefile.IsFile() {
- body, err := ioutil.ReadFile(corefile.Path())
- if err == nil {
- updatedCorefile = CorefileInput{
- Filepath: corefile.Path(),
- Contents: body,
- RealFile: true,
- }
- }
- }
- corefileMu.Unlock()
-
- err := Restart(updatedCorefile)
- if err != nil {
- log.Printf("[ERROR] SIGUSR1: %v", err)
- }
- }
- }
- }()
-}
diff --git a/core/sigtrap_windows.go b/core/sigtrap_windows.go
deleted file mode 100644
index 59132cee4..000000000
--- a/core/sigtrap_windows.go
+++ /dev/null
@@ -1,3 +0,0 @@
-package core
-
-func trapSignalsPosix() {}
diff --git a/coredns.go b/coredns.go
new file mode 100644
index 000000000..e7e54dc60
--- /dev/null
+++ b/coredns.go
@@ -0,0 +1,18 @@
+package main
+
+import (
+ "flag"
+
+ "github.com/mholt/caddy"
+ "github.com/mholt/caddy/caddy/caddymain"
+)
+
+//go:generate go run plugin_generate.go
+
+func main() {
+ // Set some flags/options specific for CoreDNS.
+ flag.Set("type", "dns")
+ caddy.DefaultConfigFile = "Corefile"
+
+ caddymain.Run()
+}
diff --git a/main.go b/main.go
deleted file mode 100644
index 1900fc60c..000000000
--- a/main.go
+++ /dev/null
@@ -1,7 +0,0 @@
-package main
-
-import "github.com/miekg/coredns/core/coremain"
-
-func main() {
- coremain.Run()
-}
diff --git a/middleware/bind/README.md b/middleware/bind/README.md
new file mode 100644
index 000000000..ad23b6153
--- /dev/null
+++ b/middleware/bind/README.md
@@ -0,0 +1,21 @@
+# bind
+
+bind overrides the host to which the server should bind. Normally, the listener binds to the
+wildcard host. However, you may force the listener to bind to another IP instead. This
+directive accepts only an address, not a port.
+
+## Syntax
+
+~~~ txt
+bind address
+~~~
+
+address is the IP address to bind to.
+
+## Examples
+
+To make your socket accessible only to that machine, bind to IP 127.0.0.1 (localhost):
+
+~~~ txt
+bind 127.0.0.1
+~~~
diff --git a/middleware/bind/bind.go b/middleware/bind/bind.go
new file mode 100644
index 000000000..ac27c993b
--- /dev/null
+++ b/middleware/bind/bind.go
@@ -0,0 +1,10 @@
+package bind
+
+import "github.com/mholt/caddy"
+
+func init() {
+ caddy.RegisterPlugin("bind", caddy.Plugin{
+ ServerType: "dns",
+ Action: setupBind,
+ })
+}
diff --git a/middleware/bind/bind_test.go b/middleware/bind/bind_test.go
new file mode 100644
index 000000000..d61741a02
--- /dev/null
+++ b/middleware/bind/bind_test.go
@@ -0,0 +1,30 @@
+package bind
+
+import (
+ "testing"
+
+ "github.com/miekg/coredns/core/dnsserver"
+
+ "github.com/mholt/caddy"
+)
+
+func TestSetupBind(t *testing.T) {
+ c := caddy.NewTestController("dns", `bind 1.2.3.4`)
+ err := setupBind(c)
+ if err != nil {
+ t.Fatalf("Expected no errors, but got: %v", err)
+ }
+
+ cfg := dnsserver.GetConfig(c)
+ if got, want := cfg.ListenHost, "1.2.3.4"; got != want {
+ t.Errorf("Expected the config's ListenHost to be %s, was %s", want, got)
+ }
+}
+
+func TestBindAddress(t *testing.T) {
+ c := caddy.NewTestController("dns", `bind 1.2.3.bla`)
+ err := setupBind(c)
+ if err == nil {
+ t.Fatalf("Expected errors, but got none")
+ }
+}
diff --git a/middleware/bind/setup.go b/middleware/bind/setup.go
new file mode 100644
index 000000000..c08098b5d
--- /dev/null
+++ b/middleware/bind/setup.go
@@ -0,0 +1,23 @@
+package bind
+
+import (
+ "fmt"
+ "net"
+
+ "github.com/miekg/coredns/core/dnsserver"
+
+ "github.com/mholt/caddy"
+)
+
+func setupBind(c *caddy.Controller) error {
+ config := dnsserver.GetConfig(c)
+ for c.Next() {
+ if !c.Args(&config.ListenHost) {
+ return c.ArgErr()
+ }
+ }
+ if net.ParseIP(config.ListenHost) == nil {
+ return fmt.Errorf("not a valid IP address: %s", config.ListenHost)
+ }
+ return nil
+}
diff --git a/core/setup/cache.go b/middleware/cache/setup.go
similarity index 54%
rename from core/setup/cache.go
rename to middleware/cache/setup.go
index f5a1cf0d9..ab7f423f2 100644
--- a/core/setup/cache.go
+++ b/middleware/cache/setup.go
@@ -1,33 +1,46 @@
-package setup
+package cache
import (
"strconv"
+ "github.com/miekg/coredns/core/dnsserver"
"github.com/miekg/coredns/middleware"
- "github.com/miekg/coredns/middleware/cache"
+
+ "github.com/mholt/caddy"
)
-// Cache sets up the root file path of the server.
-func Cache(c *Controller) (middleware.Middleware, error) {
- ttl, zones, err := cacheParse(c)
- if err != nil {
- return nil, err
- }
- return func(next middleware.Handler) middleware.Handler {
- return cache.NewCache(ttl, zones, next)
- }, nil
+func init() {
+ caddy.RegisterPlugin("cache", caddy.Plugin{
+ ServerType: "dns",
+ Action: setup,
+ })
}
-func cacheParse(c *Controller) (int, []string, error) {
+// Cache sets up the root file path of the server.
+func setup(c *caddy.Controller) error {
+ ttl, zones, err := cacheParse(c)
+ if err != nil {
+ return err
+ }
+ dnsserver.GetConfig(c).AddMiddleware(func(next dnsserver.Handler) dnsserver.Handler {
+ return NewCache(ttl, zones, next)
+ })
+
+ return nil
+}
+
+func cacheParse(c *caddy.Controller) (int, []string, error) {
var (
- err error
- ttl int
+ err error
+ ttl int
+ origins []string
)
for c.Next() {
if c.Val() == "cache" {
// cache [ttl] [zones..]
- origins := c.ServerBlockHosts
+ origins = make([]string, len(c.ServerBlockKeys))
+ copy(origins, c.ServerBlockKeys)
args := c.RemainingArgs()
if len(args) > 0 {
origins = args
@@ -38,7 +51,7 @@ func cacheParse(c *Controller) (int, []string, error) {
origins = origins[1:]
if len(origins) == 0 {
// There was *only* the ttl, revert back to server block
- origins = c.ServerBlockHosts
+ copy(origins, c.ServerBlockKeys)
}
}
}
diff --git a/middleware/chaos/setup.go b/middleware/chaos/setup.go
new file mode 100644
index 000000000..8bdb3053e
--- /dev/null
+++ b/middleware/chaos/setup.go
@@ -0,0 +1,50 @@
+package chaos
+
+import (
+ "github.com/miekg/coredns/core/dnsserver"
+
+ "github.com/mholt/caddy"
+)
+
+func init() {
+ caddy.RegisterPlugin("chaos", caddy.Plugin{
+ ServerType: "dns",
+ Action: setup,
+ })
+}
+
+func setup(c *caddy.Controller) error {
+ version, authors, err := chaosParse(c)
+ if err != nil {
+ return err
+ }
+
+ dnsserver.GetConfig(c).AddMiddleware(func(next dnsserver.Handler) dnsserver.Handler {
+ return Chaos{Next: next, Version: version, Authors: authors}
+ })
+
+ return nil
+}
+
+func chaosParse(c *caddy.Controller) (string, map[string]bool, error) {
+ version := ""
+ authors := make(map[string]bool)
+
+ for c.Next() {
+ args := c.RemainingArgs()
+ if len(args) == 0 {
+ return defaultVersion, nil, nil
+ }
+ if len(args) == 1 {
+ return args[0], nil, nil
+ }
+ version = args[0]
+ for _, a := range args[1:] {
+ authors[a] = true
+ }
+ return version, authors, nil
+ }
+ return version, authors, nil
+}
+
+const defaultVersion = "CoreDNS"
diff --git a/core/setup/chaos_test.go b/middleware/chaos/setup_test.go
similarity index 92%
rename from core/setup/chaos_test.go
rename to middleware/chaos/setup_test.go
index 8431cecef..c1741cdf6 100644
--- a/core/setup/chaos_test.go
+++ b/middleware/chaos/setup_test.go
@@ -1,12 +1,14 @@
-package setup
+package chaos
import (
"fmt"
"strings"
"testing"
+
+ "github.com/mholt/caddy"
)
-func TestChaos(t *testing.T) {
+func TestSetupChaos(t *testing.T) {
tests := []struct {
input string
shouldErr bool
@@ -32,7 +34,7 @@ func TestChaos(t *testing.T) {
}
for i, test := range tests {
- c := NewTestController(test.input)
+ c := caddy.NewTestController("dns", test.input)
version, authors, err := chaosParse(c)
if test.shouldErr && err == nil {
diff --git a/middleware/dnssec/cache_test.go b/middleware/dnssec/cache_test.go
index 0039586d5..3062f99b0 100644
--- a/middleware/dnssec/cache_test.go
+++ b/middleware/dnssec/cache_test.go
@@ -22,7 +22,7 @@ func TestCacheSet(t *testing.T) {
m := testMsg()
state := middleware.State{Req: m}
k := key(m.Answer) // calculate *before* we add the sig
- d := NewDnssec([]string{"miek.nl."}, []*DNSKEY{dnskey}, nil)
+ d := New([]string{"miek.nl."}, []*DNSKEY{dnskey}, nil)
m = d.Sign(state, "miek.nl.", time.Now().UTC())
_, ok := d.get(k)
diff --git a/middleware/dnssec/dnssec.go b/middleware/dnssec/dnssec.go
index 4907f1621..f517bfe2c 100644
--- a/middleware/dnssec/dnssec.go
+++ b/middleware/dnssec/dnssec.go
@@ -11,14 +11,15 @@ import (
)
type Dnssec struct {
- Next middleware.Handler
+ Next middleware.Handler
+
zones []string
keys []*DNSKEY
inflight *singleflight.Group
cache *gcache.Cache
}
-func NewDnssec(zones []string, keys []*DNSKEY, next middleware.Handler) Dnssec {
+func New(zones []string, keys []*DNSKEY, next middleware.Handler) Dnssec {
return Dnssec{Next: next,
zones: zones,
keys: keys,
diff --git a/middleware/dnssec/dnssec_test.go b/middleware/dnssec/dnssec_test.go
index 49b0d5d3a..10f731325 100644
--- a/middleware/dnssec/dnssec_test.go
+++ b/middleware/dnssec/dnssec_test.go
@@ -69,7 +69,7 @@ func TestSigningDifferentZone(t *testing.T) {
m := testMsgEx()
state := middleware.State{Req: m}
- d := NewDnssec([]string{"example.org."}, []*DNSKEY{key}, nil)
+ d := New([]string{"example.org."}, []*DNSKEY{key}, nil)
m = d.Sign(state, "example.org.", time.Now().UTC())
if !section(m.Answer, 1) {
t.Errorf("answer section should have 1 sig")
@@ -158,7 +158,7 @@ func testDelegationMsg() *dns.Msg {
func newDnssec(t *testing.T, zones []string) (Dnssec, func(), func()) {
k, rm1, rm2 := newKey(t)
- d := NewDnssec(zones, []*DNSKEY{k}, nil)
+ d := New(zones, []*DNSKEY{k}, nil)
return d, rm1, rm2
}
diff --git a/middleware/dnssec/handler_test.go b/middleware/dnssec/handler_test.go
index 6f537b90e..f7cb7e680 100644
--- a/middleware/dnssec/handler_test.go
+++ b/middleware/dnssec/handler_test.go
@@ -77,7 +77,7 @@ func TestLookupZone(t *testing.T) {
dnskey, rm1, rm2 := newKey(t)
defer rm1()
defer rm2()
- dh := NewDnssec([]string{"miek.nl."}, []*DNSKEY{dnskey}, fm)
+ dh := New([]string{"miek.nl."}, []*DNSKEY{dnskey}, fm)
ctx := context.TODO()
for _, tc := range dnsTestCases {
@@ -115,7 +115,7 @@ func TestLookupDNSKEY(t *testing.T) {
dnskey, rm1, rm2 := newKey(t)
defer rm1()
defer rm2()
- dh := NewDnssec([]string{"miek.nl."}, []*DNSKEY{dnskey}, test.ErrorHandler())
+ dh := New([]string{"miek.nl."}, []*DNSKEY{dnskey}, test.ErrorHandler())
ctx := context.TODO()
for _, tc := range dnssecTestCases {
diff --git a/core/setup/dnssec.go b/middleware/dnssec/setup.go
similarity index 60%
rename from core/setup/dnssec.go
rename to middleware/dnssec/setup.go
index 39f34b66f..999f85bf8 100644
--- a/core/setup/dnssec.go
+++ b/middleware/dnssec/setup.go
@@ -1,32 +1,43 @@
-package setup
+package dnssec
import (
"strings"
+ "github.com/miekg/coredns/core/dnsserver"
"github.com/miekg/coredns/middleware"
- "github.com/miekg/coredns/middleware/dnssec"
+
+ "github.com/mholt/caddy"
)
-// Dnssec sets up the dnssec middleware.
-func Dnssec(c *Controller) (middleware.Middleware, error) {
- zones, keys, err := dnssecParse(c)
- if err != nil {
- return nil, err
- }
-
- return func(next middleware.Handler) middleware.Handler {
- return dnssec.NewDnssec(zones, keys, next)
- }, nil
+func init() {
+ caddy.RegisterPlugin("dnssec", caddy.Plugin{
+ ServerType: "dns",
+ Action: setup,
+ })
}
-func dnssecParse(c *Controller) ([]string, []*dnssec.DNSKEY, error) {
+func setup(c *caddy.Controller) error {
+ zones, keys, err := dnssecParse(c)
+ if err != nil {
+ return err
+ }
+
+ dnsserver.GetConfig(c).AddMiddleware(func(next dnsserver.Handler) dnsserver.Handler {
+ return New(zones, keys, next)
+ })
+
+ return nil
+}
+
+func dnssecParse(c *caddy.Controller) ([]string, []*DNSKEY, error) {
zones := []string{}
- keys := []*dnssec.DNSKEY{}
+ keys := []*DNSKEY{}
for c.Next() {
if c.Val() == "dnssec" {
// dnssec [zones...]
- zones = c.ServerBlockHosts
+ zones = make([]string, len(c.ServerBlockKeys))
+ copy(zones, c.ServerBlockKeys)
args := c.RemainingArgs()
if len(args) > 0 {
zones = args
@@ -47,8 +58,8 @@ func dnssecParse(c *Controller) ([]string, []*dnssec.DNSKEY, error) {
return zones, keys, nil
}
-func keyParse(c *Controller) ([]*dnssec.DNSKEY, error) {
- keys := []*dnssec.DNSKEY{}
+func keyParse(c *caddy.Controller) ([]*DNSKEY, error) {
+ keys := []*DNSKEY{}
what := c.Val()
if !c.NextArg() {
@@ -68,7 +79,7 @@ func keyParse(c *Controller) ([]*dnssec.DNSKEY, error) {
if strings.HasSuffix(k, ".private") {
base = k[:len(k)-8]
}
- k, err := dnssec.ParseKeyFile(base+".key", base+".private")
+ k, err := ParseKeyFile(base+".key", base+".private")
if err != nil {
return nil, err
}
diff --git a/core/setup/dnssec_test.go b/middleware/dnssec/setup_test.go
similarity index 90%
rename from core/setup/dnssec_test.go
rename to middleware/dnssec/setup_test.go
index 364a363bd..9dbeb77fd 100644
--- a/core/setup/dnssec_test.go
+++ b/middleware/dnssec/setup_test.go
@@ -1,11 +1,13 @@
-package setup
+package dnssec
import (
"strings"
"testing"
+
+ "github.com/mholt/caddy"
)
-func TestDnssec(t *testing.T) {
+func TestSetupDnssec(t *testing.T) {
tests := []struct {
input string
shouldErr bool
@@ -22,7 +24,7 @@ func TestDnssec(t *testing.T) {
}
for i, test := range tests {
- c := NewTestController(test.input)
+ c := caddy.NewTestController("dns", test.input)
zones, keys, err := dnssecParse(c)
if test.shouldErr && err == nil {
diff --git a/core/setup/errors.go b/middleware/errors/setup.go
similarity index 78%
rename from core/setup/errors.go
rename to middleware/errors/setup.go
index bf6b56f87..e1c77373d 100644
--- a/core/setup/errors.go
+++ b/middleware/errors/setup.go
@@ -1,21 +1,28 @@
-package setup
+package errors
import (
"io"
"log"
"os"
+ "github.com/miekg/coredns/core/dnsserver"
"github.com/miekg/coredns/middleware"
- "github.com/miekg/coredns/middleware/errors"
"github.com/hashicorp/go-syslog"
+ "github.com/mholt/caddy"
)
-// Errors configures a new errors middleware instance.
-func Errors(c *Controller) (middleware.Middleware, error) {
+func init() {
+ caddy.RegisterPlugin("errors", caddy.Plugin{
+ ServerType: "dns",
+ Action: setup,
+ })
+}
+
+func setup(c *caddy.Controller) error {
handler, err := errorsParse(c)
if err != nil {
- return nil, err
+ return err
}
var writer io.Writer
@@ -30,7 +37,7 @@ func Errors(c *Controller) (middleware.Middleware, error) {
case "syslog":
writer, err = gsyslog.NewLogger(gsyslog.LOG_ERR, "LOCAL0", "coredns")
if err != nil {
- return nil, err
+ return err
}
default:
if handler.LogFile == "" {
@@ -41,7 +48,7 @@ func Errors(c *Controller) (middleware.Middleware, error) {
var file *os.File
file, err = os.OpenFile(handler.LogFile, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0644)
if err != nil {
- return nil, err
+ return err
}
if handler.LogRoller != nil {
file.Close()
@@ -55,14 +62,16 @@ func Errors(c *Controller) (middleware.Middleware, error) {
}
handler.Log = log.New(writer, "", 0)
- return func(next middleware.Handler) middleware.Handler {
+ dnsserver.GetConfig(c).AddMiddleware(func(next dnsserver.Handler) dnsserver.Handler {
handler.Next = next
return handler
- }, nil
+ })
+
+ return nil
}
-func errorsParse(c *Controller) (errors.ErrorHandler, error) {
- handler := errors.ErrorHandler{}
+func errorsParse(c *caddy.Controller) (ErrorHandler, error) {
+ handler := ErrorHandler{}
optionalBlock := func() (bool, error) {
var hadBlock bool
@@ -84,7 +93,7 @@ func errorsParse(c *Controller) (errors.ErrorHandler, error) {
if c.NextArg() {
if c.Val() == "{" {
c.IncrNest()
- logRoller, err := parseRoller(c)
+ logRoller, err := middleware.ParseRoller(c)
if err != nil {
return hadBlock, err
}
diff --git a/core/setup/errors_test.go b/middleware/errors/setup_test.go
similarity index 74%
rename from core/setup/errors_test.go
rename to middleware/errors/setup_test.go
index 42f625f92..6e5a85d08 100644
--- a/core/setup/errors_test.go
+++ b/middleware/errors/setup_test.go
@@ -1,62 +1,33 @@
-package setup
+package errors
import (
"testing"
+ "github.com/mholt/caddy"
"github.com/miekg/coredns/middleware"
- "github.com/miekg/coredns/middleware/errors"
)
-func TestErrors(t *testing.T) {
- c := NewTestController(`errors`)
- mid, err := Errors(c)
-
- if err != nil {
- t.Errorf("Expected no errors, got: %v", err)
- }
-
- if mid == nil {
- t.Fatal("Expected middleware, was nil instead")
- }
-
- handler := mid(EmptyNext)
- myHandler, ok := handler.(errors.ErrorHandler)
- if !ok {
- t.Fatalf("Expected handler to be type ErrorHandler, got: %#v", handler)
- }
-
- if myHandler.LogFile != "" {
- t.Errorf("Expected '%s' as the default LogFile", "")
- }
- if myHandler.LogRoller != nil {
- t.Errorf("Expected LogRoller to be nil, got: %v", *myHandler.LogRoller)
- }
- if !SameNext(myHandler.Next, EmptyNext) {
- t.Error("'Next' field of handler was not set properly")
- }
-}
-
func TestErrorsParse(t *testing.T) {
tests := []struct {
inputErrorsRules string
shouldErr bool
- expectedErrorHandler errors.ErrorHandler
+ expectedErrorHandler ErrorHandler
}{
- {`errors`, false, errors.ErrorHandler{
+ {`errors`, false, ErrorHandler{
LogFile: "",
}},
- {`errors errors.txt`, false, errors.ErrorHandler{
+ {`errors errors.txt`, false, ErrorHandler{
LogFile: "errors.txt",
}},
- {`errors visible`, false, errors.ErrorHandler{
+ {`errors visible`, false, ErrorHandler{
LogFile: "",
Debug: true,
}},
- {`errors { log visible }`, false, errors.ErrorHandler{
+ {`errors { log visible }`, false, ErrorHandler{
LogFile: "",
Debug: true,
}},
- {`errors { log errors.txt { size 2 age 10 keep 3 } }`, false, errors.ErrorHandler{
+ {`errors { log errors.txt { size 2 age 10 keep 3 } }`, false, ErrorHandler{
LogFile: "errors.txt",
LogRoller: &middleware.LogRoller{
MaxSize: 2,
@@ -70,7 +41,7 @@ func TestErrorsParse(t *testing.T) {
age 11
keep 5
}
-}`, false, errors.ErrorHandler{
+}`, false, ErrorHandler{
LogFile: "errors.txt",
LogRoller: &middleware.LogRoller{
MaxSize: 3,
@@ -81,7 +52,7 @@ func TestErrorsParse(t *testing.T) {
}},
}
for i, test := range tests {
- c := NewTestController(test.inputErrorsRules)
+ c := caddy.NewTestController("dns", test.inputErrorsRules)
actualErrorsRule, err := errorsParse(c)
if err == nil && test.shouldErr {
diff --git a/middleware/etcd/cname_test.go b/middleware/etcd/cname_test.go
index 4a00c05c2..ee341b7b6 100644
--- a/middleware/etcd/cname_test.go
+++ b/middleware/etcd/cname_test.go
@@ -16,6 +16,8 @@ import (
// Check the ordering of returned cname.
func TestCnameLookup(t *testing.T) {
+ etc := newEtcdMiddleware()
+
for _, serv := range servicesCname {
set(t, etc, serv.Key, 0, serv)
defer delete(t, etc, serv.Key)
diff --git a/middleware/etcd/debug_test.go b/middleware/etcd/debug_test.go
index 91796816f..82de9fe1f 100644
--- a/middleware/etcd/debug_test.go
+++ b/middleware/etcd/debug_test.go
@@ -30,12 +30,13 @@ func TestIsDebug(t *testing.T) {
}
func TestDebugLookup(t *testing.T) {
+ etc := newEtcdMiddleware()
+ etc.Debug = true
+
for _, serv := range servicesDebug {
set(t, etc, serv.Key, 0, serv)
defer delete(t, etc, serv.Key)
}
- etc.Debug = true
- defer func() { etc.Debug = false }()
for _, tc := range dnsTestCasesDebug {
m := tc.Msg()
@@ -69,6 +70,8 @@ func TestDebugLookup(t *testing.T) {
}
func TestDebugLookupFalse(t *testing.T) {
+ etc := newEtcdMiddleware()
+
for _, serv := range servicesDebug {
set(t, etc, serv.Key, 0, serv)
defer delete(t, etc, serv.Key)
diff --git a/middleware/etcd/group_test.go b/middleware/etcd/group_test.go
index f5283e3ba..7a2808d45 100644
--- a/middleware/etcd/group_test.go
+++ b/middleware/etcd/group_test.go
@@ -14,6 +14,8 @@ import (
)
func TestGroupLookup(t *testing.T) {
+ etc := newEtcdMiddleware()
+
for _, serv := range servicesGroup {
set(t, etc, serv.Key, 0, serv)
defer delete(t, etc, serv.Key)
diff --git a/middleware/etcd/multi_test.go b/middleware/etcd/multi_test.go
index 19e6ca7a3..f4b59f50b 100644
--- a/middleware/etcd/multi_test.go
+++ b/middleware/etcd/multi_test.go
@@ -14,10 +14,9 @@ import (
)
func TestMultiLookup(t *testing.T) {
+ etc := newEtcdMiddleware()
etc.Zones = []string{"skydns.test.", "miek.nl."}
- defer func() { etc.Zones = []string{"skydns.test.", "skydns_extra.test.", "in-addr.arpa."} }()
etc.Next = test.ErrorHandler()
- defer func() { etc.Next = nil }()
for _, serv := range servicesMulti {
set(t, etc, serv.Key, 0, serv)
diff --git a/middleware/etcd/other_test.go b/middleware/etcd/other_test.go
index 34971c6f1..ff37d27d2 100644
--- a/middleware/etcd/other_test.go
+++ b/middleware/etcd/other_test.go
@@ -18,6 +18,8 @@ import (
)
func TestOtherLookup(t *testing.T) {
+ etc := newEtcdMiddleware()
+
for _, serv := range servicesOther {
set(t, etc, serv.Key, 0, serv)
defer delete(t, etc, serv.Key)
diff --git a/middleware/etcd/proxy_lookup_test.go b/middleware/etcd/proxy_lookup_test.go
index 9a31eee24..5e0999fb0 100644
--- a/middleware/etcd/proxy_lookup_test.go
+++ b/middleware/etcd/proxy_lookup_test.go
@@ -15,18 +15,15 @@ import (
)
func TestProxyLookupFailDebug(t *testing.T) {
+ etc := newEtcdMiddleware()
+ etc.Proxy = proxy.New([]string{"127.0.0.1:154"})
+ etc.Debug = true
+
for _, serv := range servicesProxy {
set(t, etc, serv.Key, 0, serv)
defer delete(t, etc, serv.Key)
}
- prxy := etc.Proxy
- etc.Proxy = proxy.New([]string{"127.0.0.1:154"})
- defer func() { etc.Proxy = prxy }()
-
- etc.Debug = true
- defer func() { etc.Debug = false }()
-
for _, tc := range dnsTestCasesProxy {
m := tc.Msg()
diff --git a/core/setup/etcd.go b/middleware/etcd/setup.go
similarity index 79%
rename from core/setup/etcd.go
rename to middleware/etcd/setup.go
index b90297abd..dc1dddb0e 100644
--- a/core/setup/etcd.go
+++ b/middleware/etcd/setup.go
@@ -1,4 +1,4 @@
-package setup
+package etcd
import (
"crypto/tls"
@@ -8,39 +8,46 @@ import (
"net/http"
"time"
+ "github.com/miekg/coredns/core/dnsserver"
"github.com/miekg/coredns/middleware"
- "github.com/miekg/coredns/middleware/etcd"
"github.com/miekg/coredns/middleware/proxy"
"github.com/miekg/coredns/singleflight"
etcdc "github.com/coreos/etcd/client"
+ "github.com/mholt/caddy"
"golang.org/x/net/context"
)
-const defaultEndpoint = "http://localhost:2379"
+func init() {
+ caddy.RegisterPlugin("etcd", caddy.Plugin{
+ ServerType: "dns",
+ Action: setup,
+ })
+}
-// Etcd sets up the etcd middleware.
-func Etcd(c *Controller) (middleware.Middleware, error) {
- etcd, stubzones, err := etcdParse(c)
+func setup(c *caddy.Controller) error {
+ e, stubzones, err := etcdParse(c)
if err != nil {
- return nil, err
+ return err
}
if stubzones {
- c.Startup = append(c.Startup, func() error {
- etcd.UpdateStubZones()
+ c.OnStartup(func() error {
+ e.UpdateStubZones()
return nil
})
}
- return func(next middleware.Handler) middleware.Handler {
- etcd.Next = next
- return etcd
- }, nil
+ dnsserver.GetConfig(c).AddMiddleware(func(next dnsserver.Handler) dnsserver.Handler {
+ e.Next = next
+ return e
+ })
+
+ return nil
}
-func etcdParse(c *Controller) (*etcd.Etcd, bool, error) {
+func etcdParse(c *caddy.Controller) (*Etcd, bool, error) {
stub := make(map[string]proxy.Proxy)
- etc := etcd.Etcd{
+ etc := Etcd{
Proxy: proxy.New([]string{"8.8.8.8:53", "8.8.4.4:53"}),
PathPrefix: "skydns",
Ctx: context.Background(),
@@ -60,7 +67,8 @@ func etcdParse(c *Controller) (*etcd.Etcd, bool, error) {
etc.Client = client
etc.Zones = c.RemainingArgs()
if len(etc.Zones) == 0 {
- etc.Zones = c.ServerBlockHosts
+ etc.Zones = make([]string, len(c.ServerBlockKeys))
+ copy(etc.Zones, c.ServerBlockKeys)
}
middleware.Zones(etc.Zones).FullyQualify()
if c.NextBlock() {
@@ -72,19 +80,19 @@ func etcdParse(c *Controller) (*etcd.Etcd, bool, error) {
etc.Debug = true
case "path":
if !c.NextArg() {
- return &etcd.Etcd{}, false, c.ArgErr()
+ return &Etcd{}, false, c.ArgErr()
}
etc.PathPrefix = c.Val()
case "endpoint":
args := c.RemainingArgs()
if len(args) == 0 {
- return &etcd.Etcd{}, false, c.ArgErr()
+ return &Etcd{}, false, c.ArgErr()
}
endpoints = args
case "upstream":
args := c.RemainingArgs()
if len(args) == 0 {
- return &etcd.Etcd{}, false, c.ArgErr()
+ return &Etcd{}, false, c.ArgErr()
}
for i := 0; i < len(args); i++ {
h, p, e := net.SplitHostPort(args[i])
@@ -97,7 +105,7 @@ func etcdParse(c *Controller) (*etcd.Etcd, bool, error) {
case "tls": // cert key cacertfile
args := c.RemainingArgs()
if len(args) != 3 {
- return &etcd.Etcd{}, false, c.ArgErr()
+ return &Etcd{}, false, c.ArgErr()
}
tlsCertFile, tlsKeyFile, tlsCAcertFile = args[0], args[1], args[2]
}
@@ -109,19 +117,19 @@ func etcdParse(c *Controller) (*etcd.Etcd, bool, error) {
etc.Debug = true
case "path":
if !c.NextArg() {
- return &etcd.Etcd{}, false, c.ArgErr()
+ return &Etcd{}, false, c.ArgErr()
}
etc.PathPrefix = c.Val()
case "endpoint":
args := c.RemainingArgs()
if len(args) == 0 {
- return &etcd.Etcd{}, false, c.ArgErr()
+ return &Etcd{}, false, c.ArgErr()
}
endpoints = args
case "upstream":
args := c.RemainingArgs()
if len(args) == 0 {
- return &etcd.Etcd{}, false, c.ArgErr()
+ return &Etcd{}, false, c.ArgErr()
}
for i := 0; i < len(args); i++ {
h, p, e := net.SplitHostPort(args[i])
@@ -133,7 +141,7 @@ func etcdParse(c *Controller) (*etcd.Etcd, bool, error) {
case "tls": // cert key cacertfile
args := c.RemainingArgs()
if len(args) != 3 {
- return &etcd.Etcd{}, false, c.ArgErr()
+ return &Etcd{}, false, c.ArgErr()
}
tlsCertFile, tlsKeyFile, tlsCAcertFile = args[0], args[1], args[2]
}
@@ -141,13 +149,13 @@ func etcdParse(c *Controller) (*etcd.Etcd, bool, error) {
}
client, err := newEtcdClient(endpoints, tlsCertFile, tlsKeyFile, tlsCAcertFile)
if err != nil {
- return &etcd.Etcd{}, false, err
+ return &Etcd{}, false, err
}
etc.Client = client
return &etc, stubzones, nil
}
}
- return &etcd.Etcd{}, false, nil
+ return &Etcd{}, false, nil
}
func newEtcdClient(endpoints []string, tlsCert, tlsKey, tlsCACert string) (etcdc.KeysAPI, error) {
@@ -195,3 +203,5 @@ func newHTTPSTransport(tlsCertFile, tlsKeyFile, tlsCACertFile string) etcdc.Canc
return tr
}
+
+const defaultEndpoint = "http://localhost:2379"
diff --git a/middleware/etcd/setup_test.go b/middleware/etcd/setup_test.go
index 799b4a1bb..b522345d2 100644
--- a/middleware/etcd/setup_test.go
+++ b/middleware/etcd/setup_test.go
@@ -19,20 +19,19 @@ import (
"golang.org/x/net/context"
)
-var (
- etc *Etcd
- client etcdc.KeysAPI
- ctxt context.Context
-)
-
func init() {
ctxt, _ = context.WithTimeout(context.Background(), etcdTimeout)
+}
+
+// etc *Etcd
+func newEtcdMiddleware() *Etcd {
+ ctxt, _ = context.WithTimeout(context.Background(), etcdTimeout)
etcdCfg := etcdc.Config{
Endpoints: []string{"http://localhost:2379"},
}
cli, _ := etcdc.New(etcdCfg)
- etc = &Etcd{
+ return &Etcd{
Proxy: proxy.New([]string{"8.8.8.8:53"}),
PathPrefix: "skydns",
Ctx: context.Background(),
@@ -57,10 +56,12 @@ func delete(t *testing.T, e *Etcd, k string) {
}
func TestLookup(t *testing.T) {
+ etc := newEtcdMiddleware()
for _, serv := range services {
set(t, etc, serv.Key, 0, serv)
defer delete(t, etc, serv.Key)
}
+
for _, tc := range dnsTestCases {
m := tc.Msg()
@@ -91,3 +92,5 @@ func TestLookup(t *testing.T) {
}
}
}
+
+var ctxt context.Context
diff --git a/middleware/etcd/stub_test.go b/middleware/etcd/stub_test.go
index 1dc0901c0..b5a101dad 100644
--- a/middleware/etcd/stub_test.go
+++ b/middleware/etcd/stub_test.go
@@ -41,13 +41,14 @@ func TestStubLookup(t *testing.T) {
exampleNetStub := &msg.Service{Host: host, Port: port, Key: "a.example.net.stub.dns.skydns.test."}
servicesStub = append(servicesStub, exampleNetStub)
+ etc := newEtcdMiddleware()
+
for _, serv := range servicesStub {
set(t, etc, serv.Key, 0, serv)
defer delete(t, etc, serv.Key)
}
etc.updateStubZones()
- defer func() { etc.Stubmap = nil }()
for _, tc := range dnsTestCasesStub {
m := tc.Msg()
diff --git a/core/setup/file.go b/middleware/file/setup.go
similarity index 66%
rename from core/setup/file.go
rename to middleware/file/setup.go
index a0b90c3ca..8b44650ee 100644
--- a/core/setup/file.go
+++ b/middleware/file/setup.go
@@ -1,24 +1,32 @@
-package setup
+package file
import (
"fmt"
"net"
"os"
+ "github.com/miekg/coredns/core/dnsserver"
"github.com/miekg/coredns/middleware"
- "github.com/miekg/coredns/middleware/file"
+
+ "github.com/mholt/caddy"
)
-// File sets up the file middleware.
-func File(c *Controller) (middleware.Middleware, error) {
+func init() {
+ caddy.RegisterPlugin("file", caddy.Plugin{
+ ServerType: "dns",
+ Action: setup,
+ })
+}
+
+func setup(c *caddy.Controller) error {
zones, err := fileParse(c)
if err != nil {
- return nil, err
+ return err
}
// Add startup functions to notify the master(s).
for _, n := range zones.Names {
- c.Startup = append(c.Startup, func() error {
+ c.OnStartup(func() error {
zones.Z[n].StartupOnce.Do(func() {
if len(zones.Z[n].TransferTo) > 0 {
zones.Z[n].Notify()
@@ -29,24 +37,28 @@ func File(c *Controller) (middleware.Middleware, error) {
})
}
- return func(next middleware.Handler) middleware.Handler {
- return file.File{Next: next, Zones: zones}
- }, nil
+ dnsserver.GetConfig(c).AddMiddleware(func(next dnsserver.Handler) dnsserver.Handler {
+ return File{Next: next, Zones: zones}
+ })
+ return nil
}
-func fileParse(c *Controller) (file.Zones, error) {
- z := make(map[string]*file.Zone)
+func fileParse(c *caddy.Controller) (Zones, error) {
+ z := make(map[string]*Zone)
names := []string{}
+ origins := []string{}
+
for c.Next() {
if c.Val() == "file" {
// file db.file [zones...]
if !c.NextArg() {
- return file.Zones{}, c.ArgErr()
+ return Zones{}, c.ArgErr()
}
fileName := c.Val()
- origins := c.ServerBlockHosts
+ origins = make([]string, len(c.ServerBlockKeys))
+ copy(origins, c.ServerBlockKeys)
args := c.RemainingArgs()
if len(args) > 0 {
origins = args
@@ -55,25 +67,25 @@ func fileParse(c *Controller) (file.Zones, error) {
reader, err := os.Open(fileName)
if err != nil {
// bail out
- return file.Zones{}, err
+ return Zones{}, err
}
for i, _ := range origins {
origins[i] = middleware.Host(origins[i]).Normalize()
- zone, err := file.Parse(reader, origins[i], fileName)
+ zone, err := Parse(reader, origins[i], fileName)
if err == nil {
z[origins[i]] = zone
} else {
- return file.Zones{}, err
+ return Zones{}, err
}
names = append(names, origins[i])
}
noReload := false
for c.NextBlock() {
- t, _, e := transferParse(c)
+ t, _, e := TransferParse(c)
if e != nil {
- return file.Zones{}, e
+ return Zones{}, e
}
switch c.Val() {
case "no_reload":
@@ -89,11 +101,12 @@ func fileParse(c *Controller) (file.Zones, error) {
}
}
}
- return file.Zones{Z: z, Names: names}, nil
+ return Zones{Z: z, Names: names}, nil
}
-// transferParse parses transfer statements: 'transfer to [address...]'.
-func transferParse(c *Controller) (tos, froms []string, err error) {
+// TransferParse parses transfer statements: 'transfer to [address...]'.
+// Exported so secondary can use this as well.
+func TransferParse(c *caddy.Controller) (tos, froms []string, err error) {
what := c.Val()
if !c.NextArg() {
return nil, nil, c.ArgErr()
diff --git a/middleware/health/health.go b/middleware/health/health.go
index 035c9ca7a..1d47e409e 100644
--- a/middleware/health/health.go
+++ b/middleware/health/health.go
@@ -12,15 +12,16 @@ var once sync.Once
type Health struct {
Addr string
- ln net.Listener
- mux *http.ServeMux
+
+ ln net.Listener
+ mux *http.ServeMux
}
func health(w http.ResponseWriter, r *http.Request) {
io.WriteString(w, ok)
}
-func (h *Health) Start() error {
+func (h *Health) Startup() error {
if h.Addr == "" {
h.Addr = defAddr
}
diff --git a/middleware/health/setup.go b/middleware/health/setup.go
new file mode 100644
index 000000000..cf7667d17
--- /dev/null
+++ b/middleware/health/setup.go
@@ -0,0 +1,42 @@
+package health
+
+import "github.com/mholt/caddy"
+
+func init() {
+ caddy.RegisterPlugin("health", caddy.Plugin{
+ ServerType: "dns",
+ Action: setup,
+ })
+}
+
+func setup(c *caddy.Controller) error {
+ addr, err := healthParse(c)
+ if err != nil {
+ return err
+ }
+
+ health := &Health{Addr: addr}
+ c.OnStartup(health.Startup)
+ c.OnShutdown(health.Shutdown)
+
+ // Don't do AddMiddleware, as health is not *really* a middleware just a separate
+ // webserver running.
+
+ return nil
+}
+
+func healthParse(c *caddy.Controller) (string, error) {
+ addr := ""
+ for c.Next() {
+ args := c.RemainingArgs()
+
+ switch len(args) {
+ case 0:
+ case 1:
+ addr = args[0]
+ default:
+ return "", c.ArgErr()
+ }
+ }
+ return addr, nil
+}
diff --git a/middleware/kubernetes/controller.go b/middleware/kubernetes/controller.go
index 3fbea313e..5de16d61c 100644
--- a/middleware/kubernetes/controller.go
+++ b/middleware/kubernetes/controller.go
@@ -2,7 +2,6 @@ package kubernetes
import (
"fmt"
- "log"
"sync"
"time"
@@ -12,7 +11,7 @@ import (
"k8s.io/kubernetes/pkg/client/cache"
client "k8s.io/kubernetes/pkg/client/unversioned"
"k8s.io/kubernetes/pkg/controller/framework"
- "k8s.io/kubernetes/pkg/labels"
+ "k8s.io/kubernetes/pkg/labels"
"k8s.io/kubernetes/pkg/runtime"
"k8s.io/kubernetes/pkg/watch"
)
@@ -24,7 +23,7 @@ var (
type dnsController struct {
client *client.Client
- selector *labels.Selector
+ selector *labels.Selector
endpController *framework.Controller
svcController *framework.Controller
@@ -45,9 +44,9 @@ type dnsController struct {
// newDNSController creates a controller for coredns
func newdnsController(kubeClient *client.Client, resyncPeriod time.Duration, lselector *labels.Selector) *dnsController {
dns := dnsController{
- client: kubeClient,
- selector: lselector,
- stopCh: make(chan struct{}),
+ client: kubeClient,
+ selector: lselector,
+ stopCh: make(chan struct{}),
}
dns.endpLister.Store, dns.endpController = framework.NewInformer(
@@ -76,54 +75,54 @@ func newdnsController(kubeClient *client.Client, resyncPeriod time.Duration, lse
func serviceListFunc(c *client.Client, ns string, s *labels.Selector) func(api.ListOptions) (runtime.Object, error) {
return func(opts api.ListOptions) (runtime.Object, error) {
- if s != nil {
- opts.LabelSelector = *s
- }
+ if s != nil {
+ opts.LabelSelector = *s
+ }
return c.Services(ns).List(opts)
}
}
func serviceWatchFunc(c *client.Client, ns string, s *labels.Selector) func(options api.ListOptions) (watch.Interface, error) {
return func(options api.ListOptions) (watch.Interface, error) {
- if s != nil {
- options.LabelSelector = *s
- }
+ if s != nil {
+ options.LabelSelector = *s
+ }
return c.Services(ns).Watch(options)
}
}
func endpointsListFunc(c *client.Client, ns string, s *labels.Selector) func(api.ListOptions) (runtime.Object, error) {
return func(opts api.ListOptions) (runtime.Object, error) {
- if s != nil {
- opts.LabelSelector = *s
- }
+ if s != nil {
+ opts.LabelSelector = *s
+ }
return c.Endpoints(ns).List(opts)
}
}
func endpointsWatchFunc(c *client.Client, ns string, s *labels.Selector) func(options api.ListOptions) (watch.Interface, error) {
return func(options api.ListOptions) (watch.Interface, error) {
- if s != nil {
- options.LabelSelector = *s
- }
+ if s != nil {
+ options.LabelSelector = *s
+ }
return c.Endpoints(ns).Watch(options)
}
}
func namespaceListFunc(c *client.Client, s *labels.Selector) func(api.ListOptions) (runtime.Object, error) {
return func(opts api.ListOptions) (runtime.Object, error) {
- if s != nil {
- opts.LabelSelector = *s
- }
+ if s != nil {
+ opts.LabelSelector = *s
+ }
return c.Namespaces().List(opts)
}
}
func namespaceWatchFunc(c *client.Client, s *labels.Selector) func(options api.ListOptions) (watch.Interface, error) {
return func(options api.ListOptions) (watch.Interface, error) {
- if s != nil {
- options.LabelSelector = *s
- }
+ if s != nil {
+ options.LabelSelector = *s
+ }
return c.Namespaces().Watch(options)
}
}
@@ -140,7 +139,6 @@ func (dns *dnsController) Stop() error {
// Only try draining the workqueue if we haven't already.
if !dns.shutdown {
close(dns.stopCh)
- log.Println("shutting down controller queues")
dns.shutdown = true
return nil
@@ -151,14 +149,10 @@ func (dns *dnsController) Stop() error {
// Run starts the controller.
func (dns *dnsController) Run() {
- log.Println("[debug] Starting k8s notification controllers")
-
go dns.endpController.Run(dns.stopCh)
go dns.svcController.Run(dns.stopCh)
go dns.nsController.Run(dns.stopCh)
-
<-dns.stopCh
- log.Println("[debug] shutting down coredns controller")
}
func (dns *dnsController) GetNamespaceList() *api.NamespaceList {
@@ -203,12 +197,12 @@ func (dns *dnsController) GetServiceInNamespace(namespace string, servicename st
svcObj, svcExists, err := dns.svcLister.Store.GetByKey(svcKey)
if err != nil {
- log.Printf("error getting service %v from the cache: %v\n", svcKey, err)
+ // TODO(...): should return err here
return nil
}
if !svcExists {
- log.Printf("service %v does not exists\n", svcKey)
+ // TODO(...): should return err here
return nil
}
diff --git a/middleware/kubernetes/handler.go b/middleware/kubernetes/handler.go
index 05dfba934..1986820d5 100644
--- a/middleware/kubernetes/handler.go
+++ b/middleware/kubernetes/handler.go
@@ -2,7 +2,6 @@ package kubernetes
import (
"fmt"
- "log"
"strings"
"github.com/miekg/coredns/middleware"
@@ -12,8 +11,6 @@ import (
)
func (k Kubernetes) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
- log.Printf("[debug] here entering ServeDNS: ctx:%v dnsmsg:%v\n", ctx, r)
-
state := middleware.State{W: w, Req: r}
if state.QClass() != dns.ClassINET {
return dns.RcodeServerFailure, fmt.Errorf("can only deal with ClassINET")
diff --git a/middleware/kubernetes/kubernetes.go b/middleware/kubernetes/kubernetes.go
index 5e2a1bf53..59a044140 100644
--- a/middleware/kubernetes/kubernetes.go
+++ b/middleware/kubernetes/kubernetes.go
@@ -16,10 +16,10 @@ import (
"github.com/miekg/dns"
"k8s.io/kubernetes/pkg/api"
unversionedapi "k8s.io/kubernetes/pkg/api/unversioned"
- "k8s.io/kubernetes/pkg/labels"
unversionedclient "k8s.io/kubernetes/pkg/client/unversioned"
"k8s.io/kubernetes/pkg/client/unversioned/clientcmd"
clientcmdapi "k8s.io/kubernetes/pkg/client/unversioned/clientcmd/api"
+ "k8s.io/kubernetes/pkg/labels"
)
type Kubernetes struct {
@@ -32,10 +32,10 @@ type Kubernetes struct {
NameTemplate *nametemplate.NameTemplate
Namespaces []string
LabelSelector *unversionedapi.LabelSelector
- Selector *labels.Selector
+ Selector *labels.Selector
}
-func (g *Kubernetes) StartKubeCache() error {
+func (g *Kubernetes) InitKubeCache() error {
// For a custom api server or running outside a k8s cluster
// set URL in env.KUBERNETES_MASTER or set endpoint in Corefile
loadingRules := clientcmd.NewDefaultClientConfigLoadingRules()
@@ -46,7 +46,6 @@ func (g *Kubernetes) StartKubeCache() error {
clientConfig := clientcmd.NewNonInteractiveDeferredLoadingClientConfig(loadingRules, overrides)
config, err := clientConfig.ClientConfig()
if err != nil {
- log.Printf("[debug] error connecting to the client: %v", err)
return err
}
kubeClient, err := unversionedclient.New(config)
@@ -58,20 +57,17 @@ func (g *Kubernetes) StartKubeCache() error {
if g.LabelSelector == nil {
log.Printf("[INFO] Kubernetes middleware configured without a label selector. No label-based filtering will be performed.")
} else {
- var selector labels.Selector
+ var selector labels.Selector
selector, err = unversionedapi.LabelSelectorAsSelector(g.LabelSelector)
- g.Selector = &selector
- if err != nil {
- log.Printf("[ERROR] Unable to create Selector for LabelSelector '%s'.Error was: %s", g.LabelSelector, err)
- return err
- }
+ g.Selector = &selector
+ if err != nil {
+ log.Printf("[ERROR] Unable to create Selector for LabelSelector '%s'.Error was: %s", g.LabelSelector, err)
+ return err
+ }
log.Printf("[INFO] Kubernetes middleware configured with the label selector '%s'. Only kubernetes objects matching this label selector will be exposed.", unversionedapi.FormatLabelSelector(g.LabelSelector))
}
- log.Printf("[debug] Starting kubernetes middleware with k8s API resync period: %s", g.ResyncPeriod)
g.APIConn = newdnsController(kubeClient, g.ResyncPeriod, g.Selector)
- go g.APIConn.Run()
-
return err
}
@@ -115,7 +111,6 @@ func (g *Kubernetes) Records(name string, exact bool) ([]msg.Service, error) {
typeName string
)
- log.Printf("[debug] enter Records('%v', '%v')\n", name, exact)
zone, serviceSegments := g.getZoneForName(name)
// TODO: Implementation above globbed together segments for the serviceName if
@@ -137,30 +132,18 @@ func (g *Kubernetes) Records(name string, exact bool) ([]msg.Service, error) {
serviceName = util.WildcardStar
}
- log.Printf("[debug] published namespaces: %v\n", g.Namespaces)
-
- log.Printf("[debug] exact: %v\n", exact)
- log.Printf("[debug] zone: %v\n", zone)
- log.Printf("[debug] servicename: %v\n", serviceName)
- log.Printf("[debug] namespace: %v\n", namespace)
- log.Printf("[debug] typeName: %v\n", typeName)
- log.Printf("[debug] APIconn: %v\n", g.APIConn)
-
nsWildcard := util.SymbolContainsWildcard(namespace)
serviceWildcard := util.SymbolContainsWildcard(serviceName)
// Abort if the namespace does not contain a wildcard, and namespace is not published per CoreFile
// Case where namespace contains a wildcard is handled in Get(...) method.
if (!nsWildcard) && (len(g.Namespaces) > 0) && (!util.StringInSlice(namespace, g.Namespaces)) {
- log.Printf("[debug] Namespace '%v' is not published by Corefile\n", namespace)
return nil, nil
}
- log.Printf("before g.Get(namespace, nsWildcard, serviceName, serviceWildcard): %v %v %v %v", namespace, nsWildcard, serviceName, serviceWildcard)
+ log.Printf("[debug] before g.Get(namespace, nsWildcard, serviceName, serviceWildcard): %v %v %v %v", namespace, nsWildcard, serviceName, serviceWildcard)
k8sItems, err := g.Get(namespace, nsWildcard, serviceName, serviceWildcard)
- log.Printf("[debug] k8s items: %v\n", k8sItems)
if err != nil {
- log.Printf("[ERROR] Got error while looking up ServiceItems. Error is: %v\n", err)
return nil, err
}
if k8sItems == nil {
@@ -178,7 +161,6 @@ func (g *Kubernetes) getRecordsForServiceItems(serviceItems []api.Service, value
for _, item := range serviceItems {
clusterIP := item.Spec.ClusterIP
- log.Printf("[debug] clusterIP: %v\n", clusterIP)
// Create records by constructing record name from template...
//values.Namespace = item.Metadata.Namespace
@@ -188,13 +170,11 @@ func (g *Kubernetes) getRecordsForServiceItems(serviceItems []api.Service, value
// Create records for each exposed port...
for _, p := range item.Spec.Ports {
- log.Printf("[debug] port: %v\n", p.Port)
s := msg.Service{Host: clusterIP, Port: int(p.Port)}
records = append(records, s)
}
}
- log.Printf("[debug] records from getRecordsForServiceItems(): %v\n", records)
return records
}
@@ -202,13 +182,6 @@ func (g *Kubernetes) getRecordsForServiceItems(serviceItems []api.Service, value
func (g *Kubernetes) Get(namespace string, nsWildcard bool, servicename string, serviceWildcard bool) ([]api.Service, error) {
serviceList := g.APIConn.GetServiceList()
- /* TODO: Remove?
- if err != nil {
- log.Printf("[ERROR] Getting service list produced error: %v", err)
- return nil, err
- }
- */
-
var resultItems []api.Service
for _, item := range serviceList.Items {
@@ -216,7 +189,6 @@ func (g *Kubernetes) Get(namespace string, nsWildcard bool, servicename string,
// If namespace has a wildcard, filter results against Corefile namespace list.
// (Namespaces without a wildcard were filtered before the call to this function.)
if nsWildcard && (len(g.Namespaces) > 0) && (!util.StringInSlice(item.Namespace, g.Namespaces)) {
- log.Printf("[debug] Namespace '%v' is not published by Corefile\n", item.Namespace)
continue
}
resultItems = append(resultItems, item)
diff --git a/middleware/kubernetes/nametemplate/nametemplate.go b/middleware/kubernetes/nametemplate/nametemplate.go
index 5a34ae4ad..3e1ac4bb3 100644
--- a/middleware/kubernetes/nametemplate/nametemplate.go
+++ b/middleware/kubernetes/nametemplate/nametemplate.go
@@ -2,7 +2,6 @@ package nametemplate
import (
"errors"
- "log"
"strings"
"github.com/miekg/coredns/middleware/kubernetes/util"
@@ -87,17 +86,13 @@ func (t *NameTemplate) SetTemplate(s string) error {
if !elementPositionSet {
if strings.Contains(v, "{") {
err = errors.New("Record name template contains the unknown symbol '" + v + "'")
- log.Printf("[debug] %v\n", err)
return err
- } else {
- log.Printf("[debug] Template string has static element '%v'\n", v)
}
}
}
if err == nil && !t.IsValid() {
err = errors.New("Record name template does not pass NameTemplate validation")
- log.Printf("[debug] %v\n", err)
return err
}
diff --git a/core/setup/kubernetes.go b/middleware/kubernetes/setup.go
similarity index 57%
rename from core/setup/kubernetes.go
rename to middleware/kubernetes/setup.go
index 7439a9f1b..fc3c036b8 100644
--- a/core/setup/kubernetes.go
+++ b/middleware/kubernetes/setup.go
@@ -1,70 +1,78 @@
-package setup
+package kubernetes
import (
"errors"
"fmt"
- "log"
"strings"
"time"
+ "github.com/miekg/coredns/core/dnsserver"
"github.com/miekg/coredns/middleware"
- "github.com/miekg/coredns/middleware/kubernetes"
"github.com/miekg/coredns/middleware/kubernetes/nametemplate"
+
+ "github.com/mholt/caddy"
unversionedapi "k8s.io/kubernetes/pkg/api/unversioned"
)
-const (
- defaultNameTemplate = "{service}.{namespace}.{zone}"
- defaultResyncPeriod = 5 * time.Minute
-)
-
-// Kubernetes sets up the kubernetes middleware.
-func Kubernetes(c *Controller) (middleware.Middleware, error) {
- kubernetes, err := kubernetesParse(c)
- if err != nil {
- return nil, err
- }
-
- err = kubernetes.StartKubeCache()
- if err != nil {
- return nil, err
- }
-
- return func(next middleware.Handler) middleware.Handler {
- kubernetes.Next = next
- return kubernetes
- }, nil
+func init() {
+ caddy.RegisterPlugin("kubernetes", caddy.Plugin{
+ ServerType: "dns",
+ Action: setup,
+ })
}
-func kubernetesParse(c *Controller) (kubernetes.Kubernetes, error) {
+func setup(c *caddy.Controller) error {
+ kubernetes, err := kubernetesParse(c)
+ if err != nil {
+ return err
+ }
+
+ err = kubernetes.InitKubeCache()
+ if err != nil {
+ return err
+ }
+
+ // Register KubeCache start and stop functions with Caddy
+ c.OnStartup(func() error {
+ go kubernetes.APIConn.Run()
+ return nil
+ })
+
+ c.OnShutdown(func() error {
+ return kubernetes.APIConn.Stop()
+ })
+
+ dnsserver.GetConfig(c).AddMiddleware(func(next dnsserver.Handler) dnsserver.Handler {
+ kubernetes.Next = next
+ return kubernetes
+ })
+
+ return nil
+}
+
+func kubernetesParse(c *caddy.Controller) (Kubernetes, error) {
var err error
template := defaultNameTemplate
- k8s := kubernetes.Kubernetes{
- ResyncPeriod: defaultResyncPeriod,
- }
+ k8s := Kubernetes{ResyncPeriod: defaultResyncPeriod}
k8s.NameTemplate = new(nametemplate.NameTemplate)
k8s.NameTemplate.SetTemplate(template)
- // TODO: expose resync period in Corefile
-
for c.Next() {
if c.Val() == "kubernetes" {
zones := c.RemainingArgs()
if len(zones) == 0 {
- k8s.Zones = c.ServerBlockHosts
- log.Printf("[debug] Zones(from ServerBlockHosts): %v", zones)
- } else {
- // Normalize requested zones
- k8s.Zones = kubernetes.NormalizeZoneList(zones)
+ k8s.Zones = make([]string, len(c.ServerBlockKeys))
+ copy(k8s.Zones, c.ServerBlockKeys)
}
+ k8s.Zones = NormalizeZoneList(zones)
middleware.Zones(k8s.Zones).FullyQualify()
+
if k8s.Zones == nil || len(k8s.Zones) < 1 {
err = errors.New("Zone name must be provided for kubernetes middleware.")
- log.Printf("[debug] %v\n", err)
- return kubernetes.Kubernetes{}, err
+ return Kubernetes{}, err
}
for c.NextBlock() {
@@ -75,27 +83,24 @@ func kubernetesParse(c *Controller) (kubernetes.Kubernetes, error) {
template := strings.Join(args, "")
err = k8s.NameTemplate.SetTemplate(template)
if err != nil {
- return kubernetes.Kubernetes{}, err
+ return Kubernetes{}, err
}
} else {
- log.Printf("[debug] 'template' keyword provided without any template value.")
- return kubernetes.Kubernetes{}, c.ArgErr()
+ return Kubernetes{}, c.ArgErr()
}
case "namespaces":
args := c.RemainingArgs()
if len(args) != 0 {
k8s.Namespaces = append(k8s.Namespaces, args...)
} else {
- log.Printf("[debug] 'namespaces' keyword provided without any namespace values.")
- return kubernetes.Kubernetes{}, c.ArgErr()
+ return Kubernetes{}, c.ArgErr()
}
case "endpoint":
args := c.RemainingArgs()
if len(args) != 0 {
k8s.APIEndpoint = args[0]
} else {
- log.Printf("[debug] 'endpoint' keyword provided without any endpoint url value.")
- return kubernetes.Kubernetes{}, c.ArgErr()
+ return Kubernetes{}, c.ArgErr()
}
case "resyncperiod":
args := c.RemainingArgs()
@@ -103,12 +108,10 @@ func kubernetesParse(c *Controller) (kubernetes.Kubernetes, error) {
k8s.ResyncPeriod, err = time.ParseDuration(args[0])
if err != nil {
err = errors.New(fmt.Sprintf("Unable to parse resync duration value. Value provided was '%v'. Example valid values: '15s', '5m', '1h'. Error was: %v", args[0], err))
- log.Printf("[ERROR] %v", err)
- return kubernetes.Kubernetes{}, err
+ return Kubernetes{}, err
}
} else {
- log.Printf("[debug] 'resyncperiod' keyword provided without any duration value.")
- return kubernetes.Kubernetes{}, c.ArgErr()
+ return Kubernetes{}, c.ArgErr()
}
case "labels":
args := c.RemainingArgs()
@@ -117,12 +120,10 @@ func kubernetesParse(c *Controller) (kubernetes.Kubernetes, error) {
k8s.LabelSelector, err = unversionedapi.ParseToLabelSelector(labelSelectorString)
if err != nil {
err = errors.New(fmt.Sprintf("Unable to parse label selector. Value provided was '%v'. Error was: %v", labelSelectorString, err))
- log.Printf("[ERROR] %v", err)
- return kubernetes.Kubernetes{}, err
+ return Kubernetes{}, err
}
} else {
- log.Printf("[debug] 'labels' keyword provided without any selector value.")
- return kubernetes.Kubernetes{}, c.ArgErr()
+ return Kubernetes{}, c.ArgErr()
}
}
}
@@ -130,6 +131,10 @@ func kubernetesParse(c *Controller) (kubernetes.Kubernetes, error) {
}
}
err = errors.New("Kubernetes setup called without keyword 'kubernetes' in Corefile")
- log.Printf("[ERROR] %v\n", err)
- return kubernetes.Kubernetes{}, err
+ return Kubernetes{}, err
}
+
+const (
+ defaultNameTemplate = "{service}.{namespace}.{zone}"
+ defaultResyncPeriod = 5 * time.Minute
+)
diff --git a/core/setup/kubernetes_test.go b/middleware/kubernetes/setup_test.go
similarity index 98%
rename from core/setup/kubernetes_test.go
rename to middleware/kubernetes/setup_test.go
index cf6ac9abc..6e0918a3b 100644
--- a/core/setup/kubernetes_test.go
+++ b/middleware/kubernetes/setup_test.go
@@ -1,10 +1,11 @@
-package setup
+package kubernetes
import (
"strings"
"testing"
"time"
+ "github.com/mholt/caddy"
unversionedapi "k8s.io/kubernetes/pkg/api/unversioned"
)
@@ -320,7 +321,7 @@ func TestKubernetesParse(t *testing.T) {
t.Logf("Parser test cases count: %v", len(tests))
for i, test := range tests {
- c := NewTestController(test.input)
+ c := caddy.NewTestController("dns", test.input)
k8sController, err := kubernetesParse(c)
t.Logf("setup test: %2v -- %v\n", i, test.description)
//t.Logf("controller: %v\n", k8sController)
diff --git a/middleware/loadbalance/setup.go b/middleware/loadbalance/setup.go
new file mode 100644
index 000000000..ef3f35f03
--- /dev/null
+++ b/middleware/loadbalance/setup.go
@@ -0,0 +1,25 @@
+package loadbalance
+
+import (
+ "github.com/mholt/caddy"
+ "github.com/miekg/coredns/core/dnsserver"
+)
+
+func init() {
+ caddy.RegisterPlugin("loadbalance", caddy.Plugin{
+ ServerType: "dns",
+ Action: setup,
+ })
+}
+
+func setup(c *caddy.Controller) error {
+ for c.Next() {
+ // TODO(miek): block and option parsing
+ }
+
+ dnsserver.GetConfig(c).AddMiddleware(func(next dnsserver.Handler) dnsserver.Handler {
+ return RoundRobin{Next: next}
+ })
+
+ return nil
+}
diff --git a/core/setup/log.go b/middleware/log/setup.go
similarity index 70%
rename from core/setup/log.go
rename to middleware/log/setup.go
index 1deca2565..a80e69ac3 100644
--- a/core/setup/log.go
+++ b/middleware/log/setup.go
@@ -1,26 +1,34 @@
-package setup
+package log
import (
"io"
"log"
"os"
- "github.com/hashicorp/go-syslog"
+ "github.com/miekg/coredns/core/dnsserver"
"github.com/miekg/coredns/middleware"
- corednslog "github.com/miekg/coredns/middleware/log"
"github.com/miekg/coredns/server"
+
+ "github.com/hashicorp/go-syslog"
+ "github.com/mholt/caddy"
"github.com/miekg/dns"
)
-// Log sets up the logging middleware.
-func Log(c *Controller) (middleware.Middleware, error) {
+func init() {
+ caddy.RegisterPlugin("log", caddy.Plugin{
+ ServerType: "dns",
+ Action: setup,
+ })
+}
+
+func setup(c *caddy.Controller) error {
rules, err := logParse(c)
if err != nil {
- return nil, err
+ return err
}
// Open the log files for writing when the server starts
- c.Startup = append(c.Startup, func() error {
+ c.OnStartup(func() error {
for i := 0; i < len(rules); i++ {
var err error
var writer io.Writer
@@ -55,13 +63,15 @@ func Log(c *Controller) (middleware.Middleware, error) {
return nil
})
- return func(next middleware.Handler) middleware.Handler {
- return corednslog.Logger{Next: next, Rules: rules, ErrorFunc: server.DefaultErrorFunc}
- }, nil
+ dnsserver.GetConfig(c).AddMiddleware(func(next dnsserver.Handler) dnsserver.Handler {
+ return Logger{Next: next, Rules: rules, ErrorFunc: server.DefaultErrorFunc}
+ })
+
+ return nil
}
-func logParse(c *Controller) ([]corednslog.Rule, error) {
- var rules []corednslog.Rule
+func logParse(c *caddy.Controller) ([]Rule, error) {
+ var rules []Rule
for c.Next() {
args := c.RemainingArgs()
@@ -72,7 +82,7 @@ func logParse(c *Controller) ([]corednslog.Rule, error) {
if c.NextArg() {
if c.Val() == "{" {
var err error
- logRoller, err = parseRoller(c)
+ logRoller, err = middleware.ParseRoller(c)
if err != nil {
return nil, err
}
@@ -88,37 +98,37 @@ func logParse(c *Controller) ([]corednslog.Rule, error) {
}
if len(args) == 0 {
// Nothing specified; use defaults
- rules = append(rules, corednslog.Rule{
+ rules = append(rules, Rule{
NameScope: ".",
- OutputFile: corednslog.DefaultLogFilename,
- Format: corednslog.DefaultLogFormat,
+ OutputFile: DefaultLogFilename,
+ Format: DefaultLogFormat,
Roller: logRoller,
})
} else if len(args) == 1 {
// Only an output file specified
- rules = append(rules, corednslog.Rule{
+ rules = append(rules, Rule{
NameScope: ".",
OutputFile: args[0],
- Format: corednslog.DefaultLogFormat,
+ Format: DefaultLogFormat,
Roller: logRoller,
})
} else {
// Name scope, output file, and maybe a format specified
- format := corednslog.DefaultLogFormat
+ format := DefaultLogFormat
if len(args) > 2 {
switch args[2] {
case "{common}":
- format = corednslog.CommonLogFormat
+ format = CommonLogFormat
case "{combined}":
- format = corednslog.CombinedLogFormat
+ format = CombinedLogFormat
default:
format = args[2]
}
}
- rules = append(rules, corednslog.Rule{
+ rules = append(rules, Rule{
NameScope: dns.Fqdn(args[0]),
OutputFile: args[1],
Format: format,
diff --git a/core/setup/log_test.go b/middleware/log/setup_test.go
similarity index 65%
rename from core/setup/log_test.go
rename to middleware/log/setup_test.go
index ad9cb7c3a..0a3ee63fe 100644
--- a/core/setup/log_test.go
+++ b/middleware/log/setup_test.go
@@ -1,99 +1,61 @@
-package setup
+package log
import (
"testing"
"github.com/miekg/coredns/middleware"
- corednslog "github.com/miekg/coredns/middleware/log"
+
+ "github.com/mholt/caddy"
)
-func TestLog(t *testing.T) {
-
- c := NewTestController(`log`)
-
- mid, err := Log(c)
-
- if err != nil {
- t.Errorf("Expected no errors, got: %v", err)
- }
-
- if mid == nil {
- t.Fatal("Expected middleware, was nil instead")
- }
-
- handler := mid(EmptyNext)
- myHandler, ok := handler.(corednslog.Logger)
-
- if !ok {
- t.Fatalf("Expected handler to be type Logger, got: %#v", handler)
- }
-
- if myHandler.Rules[0].NameScope != "." {
- t.Errorf("Expected . as the default NameScope")
- }
- if myHandler.Rules[0].OutputFile != corednslog.DefaultLogFilename {
- t.Errorf("Expected %s as the default OutputFile", corednslog.DefaultLogFilename)
- }
- if myHandler.Rules[0].Format != corednslog.DefaultLogFormat {
- t.Errorf("Expected %s as the default Log Format", corednslog.DefaultLogFormat)
- }
- if myHandler.Rules[0].Roller != nil {
- t.Errorf("Expected Roller to be nil, got: %v", *myHandler.Rules[0].Roller)
- }
- if !SameNext(myHandler.Next, EmptyNext) {
- t.Error("'Next' field of handler was not set properly")
- }
-
-}
-
func TestLogParse(t *testing.T) {
tests := []struct {
inputLogRules string
shouldErr bool
- expectedLogRules []corednslog.Rule
+ expectedLogRules []Rule
}{
- {`log`, false, []corednslog.Rule{{
+ {`log`, false, []Rule{{
NameScope: ".",
- OutputFile: corednslog.DefaultLogFilename,
- Format: corednslog.DefaultLogFormat,
+ OutputFile: DefaultLogFilename,
+ Format: DefaultLogFormat,
}}},
- {`log log.txt`, false, []corednslog.Rule{{
+ {`log log.txt`, false, []Rule{{
NameScope: ".",
OutputFile: "log.txt",
- Format: corednslog.DefaultLogFormat,
+ Format: DefaultLogFormat,
}}},
- {`log example.org log.txt`, false, []corednslog.Rule{{
+ {`log example.org log.txt`, false, []Rule{{
NameScope: "example.org.",
OutputFile: "log.txt",
- Format: corednslog.DefaultLogFormat,
+ Format: DefaultLogFormat,
}}},
- {`log example.org. stdout`, false, []corednslog.Rule{{
+ {`log example.org. stdout`, false, []Rule{{
NameScope: "example.org.",
OutputFile: "stdout",
- Format: corednslog.DefaultLogFormat,
+ Format: DefaultLogFormat,
}}},
- {`log example.org log.txt {common}`, false, []corednslog.Rule{{
+ {`log example.org log.txt {common}`, false, []Rule{{
NameScope: "example.org.",
OutputFile: "log.txt",
- Format: corednslog.CommonLogFormat,
+ Format: CommonLogFormat,
}}},
- {`log example.org accesslog.txt {combined}`, false, []corednslog.Rule{{
+ {`log example.org accesslog.txt {combined}`, false, []Rule{{
NameScope: "example.org.",
OutputFile: "accesslog.txt",
- Format: corednslog.CombinedLogFormat,
+ Format: CombinedLogFormat,
}}},
{`log example.org. log.txt
- log example.net accesslog.txt {combined}`, false, []corednslog.Rule{{
+ log example.net accesslog.txt {combined}`, false, []Rule{{
NameScope: "example.org.",
OutputFile: "log.txt",
- Format: corednslog.DefaultLogFormat,
+ Format: DefaultLogFormat,
}, {
NameScope: "example.net.",
OutputFile: "accesslog.txt",
- Format: corednslog.CombinedLogFormat,
+ Format: CombinedLogFormat,
}}},
{`log example.org stdout {host}
- log example.org log.txt {when}`, false, []corednslog.Rule{{
+ log example.org log.txt {when}`, false, []Rule{{
NameScope: "example.org.",
OutputFile: "stdout",
Format: "{host}",
@@ -102,10 +64,10 @@ func TestLogParse(t *testing.T) {
OutputFile: "log.txt",
Format: "{when}",
}}},
- {`log access.log { rotate { size 2 age 10 keep 3 } }`, false, []corednslog.Rule{{
+ {`log access.log { rotate { size 2 age 10 keep 3 } }`, false, []Rule{{
NameScope: ".",
OutputFile: "access.log",
- Format: corednslog.DefaultLogFormat,
+ Format: DefaultLogFormat,
Roller: &middleware.LogRoller{
MaxSize: 2,
MaxAge: 10,
@@ -115,7 +77,7 @@ func TestLogParse(t *testing.T) {
}}},
}
for i, test := range tests {
- c := NewTestController(test.inputLogRules)
+ c := caddy.NewTestController("dns", test.inputLogRules)
actualLogRules, err := logParse(c)
if err == nil && test.shouldErr {
diff --git a/middleware/metrics/metrics.go b/middleware/metrics/metrics.go
index 1c7db29d2..3a72bc3bb 100644
--- a/middleware/metrics/metrics.go
+++ b/middleware/metrics/metrics.go
@@ -34,7 +34,7 @@ type Metrics struct {
ZoneNames []string
}
-func (m *Metrics) Start() error {
+func (m *Metrics) Startup() error {
m.Once.Do(func() {
define()
diff --git a/middleware/metrics/setup.go b/middleware/metrics/setup.go
new file mode 100644
index 000000000..f31cbd4d5
--- /dev/null
+++ b/middleware/metrics/setup.go
@@ -0,0 +1,84 @@
+package metrics
+
+import (
+ "sync"
+
+ "github.com/miekg/coredns/core/dnsserver"
+ "github.com/miekg/coredns/middleware"
+
+ "github.com/mholt/caddy"
+)
+
+func init() {
+ caddy.RegisterPlugin("prometheus", caddy.Plugin{
+ ServerType: "dns",
+ Action: setup,
+ })
+}
+
+func setup(c *caddy.Controller) error {
+ m, err := prometheusParse(c)
+ if err != nil {
+ return err
+ }
+
+ dnsserver.GetConfig(c).AddMiddleware(func(next dnsserver.Handler) dnsserver.Handler {
+ m.Next = next
+ return m
+ })
+
+ metricsOnce.Do(func() {
+ c.OnStartup(m.Startup)
+ c.OnShutdown(m.Shutdown)
+ })
+
+ return nil
+}
+
+func prometheusParse(c *caddy.Controller) (Metrics, error) {
+ var (
+ met Metrics
+ err error
+ )
+
+ for c.Next() {
+ if len(met.ZoneNames) > 0 {
+ return Metrics{}, c.Err("metrics: can only have one metrics module per server")
+ }
+ met.ZoneNames = make([]string, len(c.ServerBlockKeys))
+ copy(met.ZoneNames, c.ServerBlockKeys)
+ for i, _ := range met.ZoneNames {
+ met.ZoneNames[i] = middleware.Host(met.ZoneNames[i]).Normalize()
+ }
+ args := c.RemainingArgs()
+
+ switch len(args) {
+ case 0:
+ case 1:
+ met.Addr = args[0]
+ default:
+ return Metrics{}, c.ArgErr()
+ }
+ for c.NextBlock() {
+ switch c.Val() {
+ case "address":
+ args = c.RemainingArgs()
+ if len(args) != 1 {
+ return Metrics{}, c.ArgErr()
+ }
+ met.Addr = args[0]
+ default:
+ return Metrics{}, c.Errf("metrics: unknown item: %s", c.Val())
+ }
+
+ }
+ }
+ if met.Addr == "" {
+ met.Addr = addr
+ }
+ return met, err
+}
+
+var metricsOnce sync.Once
+
+const addr = "localhost:9153"
diff --git a/middleware/pprof/pprof.go b/middleware/pprof/pprof.go
index 42102d198..f538b3091 100644
--- a/middleware/pprof/pprof.go
+++ b/middleware/pprof/pprof.go
@@ -12,7 +12,7 @@ type Handler struct {
mux *http.ServeMux
}
-func (h *Handler) Start() error {
+func (h *Handler) Startup() error {
if ln, err := net.Listen("tcp", addr); err != nil {
log.Printf("[ERROR] Failed to start pprof handler: %s", err)
return err
diff --git a/middleware/pprof/setup.go b/middleware/pprof/setup.go
new file mode 100644
index 000000000..2b6701f35
--- /dev/null
+++ b/middleware/pprof/setup.go
@@ -0,0 +1,40 @@
+package pprof
+
+import (
+ "sync"
+
+ "github.com/mholt/caddy"
+)
+
+func init() {
+ caddy.RegisterPlugin("pprof", caddy.Plugin{
+ ServerType: "dns",
+ Action: setup,
+ })
+}
+
+func setup(c *caddy.Controller) error {
+ found := false
+ for c.Next() {
+ if found {
+ return c.Err("pprof can only be specified once")
+ }
+ if len(c.RemainingArgs()) != 0 {
+ return c.ArgErr()
+ }
+ if c.NextBlock() {
+ return c.ArgErr()
+ }
+ found = true
+ }
+
+ handler := &Handler{}
+ pprofOnce.Do(func() {
+ c.OnStartup(handler.Startup)
+ c.OnShutdown(handler.Shutdown)
+ })
+
+ return nil
+}
+
+var pprofOnce sync.Once
diff --git a/core/setup/pprof_test.go b/middleware/pprof/setup_test.go
similarity index 78%
rename from core/setup/pprof_test.go
rename to middleware/pprof/setup_test.go
index ac9375af7..af46fd415 100644
--- a/core/setup/pprof_test.go
+++ b/middleware/pprof/setup_test.go
@@ -1,6 +1,10 @@
-package setup
+package pprof
-import "testing"
+import (
+ "testing"
+
+ "github.com/mholt/caddy"
+)
func TestPProf(t *testing.T) {
tests := []struct {
@@ -17,8 +21,8 @@ func TestPProf(t *testing.T) {
pprof`, true},
}
for i, test := range tests {
- c := NewTestController(test.input)
- _, err := PProf(c)
+ c := caddy.NewTestController("dns", test.input)
+ err := setup(c)
if test.shouldErr && err == nil {
t.Errorf("Test %v: Expected error but found nil", i)
} else if !test.shouldErr && err != nil {
diff --git a/middleware/proxy/setup.go b/middleware/proxy/setup.go
new file mode 100644
index 000000000..81dc7777c
--- /dev/null
+++ b/middleware/proxy/setup.go
@@ -0,0 +1,26 @@
+package proxy
+
+import (
+ "github.com/miekg/coredns/core/dnsserver"
+
+ "github.com/mholt/caddy"
+)
+
+func init() {
+ caddy.RegisterPlugin("proxy", caddy.Plugin{
+ ServerType: "dns",
+ Action: setup,
+ })
+}
+
+func setup(c *caddy.Controller) error {
+ upstreams, err := NewStaticUpstreams(c.Dispenser)
+ if err != nil {
+ return err
+ }
+ dnsserver.GetConfig(c).AddMiddleware(func(next dnsserver.Handler) dnsserver.Handler {
+ return Proxy{Next: next, Client: Clients(), Upstreams: upstreams}
+ })
+
+ return nil
+}
diff --git a/middleware/proxy/upstream.go b/middleware/proxy/upstream.go
index 76c27a45a..12ec00d76 100644
--- a/middleware/proxy/upstream.go
+++ b/middleware/proxy/upstream.go
@@ -11,8 +11,9 @@ import (
"sync/atomic"
"time"
- "github.com/miekg/coredns/core/parse"
"github.com/miekg/coredns/middleware"
+
+ "github.com/mholt/caddy/caddyfile"
"github.com/miekg/dns"
)
@@ -43,7 +44,7 @@ type Options struct {
// NewStaticUpstreams parses the configuration input and sets up
// static upstreams for the proxy middleware.
-func NewStaticUpstreams(c parse.Dispenser) ([]Upstream, error) {
+func NewStaticUpstreams(c caddyfile.Dispenser) ([]Upstream, error) {
var upstreams []Upstream
for c.Next() {
upstream := &staticUpstream{
@@ -73,7 +74,7 @@ func NewStaticUpstreams(c parse.Dispenser) ([]Upstream, error) {
}
for c.NextBlock() {
- if err := parseBlock(&c, upstream); err != nil {
+ if err := parseBlock(c, upstream); err != nil {
return upstreams, err
}
}
@@ -125,7 +126,7 @@ func (u *staticUpstream) Options() Options {
return u.options
}
-func parseBlock(c *parse.Dispenser, u *staticUpstream) error {
+func parseBlock(c caddyfile.Dispenser, u *staticUpstream) error {
switch c.Val() {
case "policy":
if !c.NextArg() {
diff --git a/core/setup/rewrite.go b/middleware/rewrite/setup.go
similarity index 70%
rename from core/setup/rewrite.go
rename to middleware/rewrite/setup.go
index 86bef2ca3..abfc0fbd6 100644
--- a/core/setup/rewrite.go
+++ b/middleware/rewrite/setup.go
@@ -1,34 +1,40 @@
-package setup
+package rewrite
import (
"strconv"
"strings"
- "github.com/miekg/coredns/middleware"
- "github.com/miekg/coredns/middleware/rewrite"
+ "github.com/miekg/coredns/core/dnsserver"
+
+ "github.com/mholt/caddy"
)
-// Rewrite configures a new Rewrite middleware instance.
-func Rewrite(c *Controller) (middleware.Middleware, error) {
- rewrites, err := rewriteParse(c)
- if err != nil {
- return nil, err
- }
-
- return func(next middleware.Handler) middleware.Handler {
- return rewrite.Rewrite{
- Next: next,
- Rules: rewrites,
- }
- }, nil
+func init() {
+ caddy.RegisterPlugin("rewrite", caddy.Plugin{
+ ServerType: "dns",
+ Action: setup,
+ })
}
-func rewriteParse(c *Controller) ([]rewrite.Rule, error) {
- var simpleRules []rewrite.Rule
- var regexpRules []rewrite.Rule
+func setup(c *caddy.Controller) error {
+ rewrites, err := rewriteParse(c)
+ if err != nil {
+ return err
+ }
+
+ dnsserver.GetConfig(c).AddMiddleware(func(next dnsserver.Handler) dnsserver.Handler {
+ return Rewrite{Next: next, Rules: rewrites}
+ })
+
+ return nil
+}
+
+func rewriteParse(c *caddy.Controller) ([]Rule, error) {
+ var simpleRules []Rule
+ var regexpRules []Rule
for c.Next() {
- var rule rewrite.Rule
+ var rule Rule
var err error
var base = "."
var pattern, to string
@@ -37,7 +43,7 @@ func rewriteParse(c *Controller) ([]rewrite.Rule, error) {
args := c.RemainingArgs()
- var ifs []rewrite.If
+ var ifs []If
switch len(args) {
case 1:
@@ -68,7 +74,7 @@ func rewriteParse(c *Controller) ([]rewrite.Rule, error) {
if len(args1) != 3 {
return nil, c.ArgErr()
}
- ifCond, err := rewrite.NewIf(args1[0], args1[1], args1[2])
+ ifCond, err := NewIf(args1[0], args1[1], args1[2])
if err != nil {
return nil, err
}
@@ -92,14 +98,14 @@ func rewriteParse(c *Controller) ([]rewrite.Rule, error) {
// TODO(miek): complex rules
base, pattern, to, status, ext, ifs = base, pattern, to, status, ext, ifs
err = err
- // if rule, err = rewrite.NewComplexRule(base, pattern, to, status, ext, ifs); err != nil {
+ // if rule, err = NewComplexRule(base, pattern, to, status, ext, ifs); err != nil {
// return nil, err
// }
regexpRules = append(regexpRules, rule)
// the only unhandled case is 2 and above
default:
- rule = rewrite.NewSimpleRule(args[0], strings.Join(args[1:], " "))
+ rule = NewSimpleRule(args[0], strings.Join(args[1:], " "))
simpleRules = append(simpleRules, rule)
}
}
diff --git a/middleware/roller.go b/middleware/roller.go
index 995cabf91..81ff71c44 100644
--- a/middleware/roller.go
+++ b/middleware/roller.go
@@ -2,10 +2,45 @@ package middleware
import (
"io"
+ "strconv"
+ "github.com/mholt/caddy"
"gopkg.in/natefinch/lumberjack.v2"
)
+func ParseRoller(c *caddy.Controller) (*LogRoller, error) {
+ var size, age, keep int
+ // This is kind of a hack to support nested blocks:
+ // As we are already in a block: either log or errors,
+ // c.nesting > 0 but, as soon as c meets a }, it thinks
+ // the block is over and return false for c.NextBlock.
+ for c.NextBlock() {
+ what := c.Val()
+ if !c.NextArg() {
+ return nil, c.ArgErr()
+ }
+ value := c.Val()
+ var err error
+ switch what {
+ case "size":
+ size, err = strconv.Atoi(value)
+ case "age":
+ age, err = strconv.Atoi(value)
+ case "keep":
+ keep, err = strconv.Atoi(value)
+ }
+ if err != nil {
+ return nil, err
+ }
+ }
+ return &LogRoller{
+ MaxSize: size,
+ MaxAge: age,
+ MaxBackups: keep,
+ LocalTime: true,
+ }, nil
+}
+
// LogRoller implements a middleware that provides a rolling logger.
type LogRoller struct {
Filename string
diff --git a/core/setup/secondary.go b/middleware/secondary/setup.go
similarity index 64%
rename from core/setup/secondary.go
rename to middleware/secondary/setup.go
index e1f54a651..5550cf11d 100644
--- a/core/setup/secondary.go
+++ b/middleware/secondary/setup.go
@@ -1,22 +1,30 @@
-package setup
+package secondary
import (
+ "github.com/miekg/coredns/core/dnsserver"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/file"
- "github.com/miekg/coredns/middleware/secondary"
+
+ "github.com/mholt/caddy"
)
-// Secondary sets up the secondary middleware.
-func Secondary(c *Controller) (middleware.Middleware, error) {
+func init() {
+ caddy.RegisterPlugin("secondary", caddy.Plugin{
+ ServerType: "dns",
+ Action: setup,
+ })
+}
+
+func setup(c *caddy.Controller) error {
zones, err := secondaryParse(c)
if err != nil {
- return nil, err
+ return err
}
// Add startup functions to retrieve the zone and keep it up to date.
for _, n := range zones.Names {
if len(zones.Z[n].TransferFrom) > 0 {
- c.Startup = append(c.Startup, func() error {
+ c.OnStartup(func() error {
zones.Z[n].StartupOnce.Do(func() {
zones.Z[n].TransferIn()
go func() {
@@ -28,19 +36,22 @@ func Secondary(c *Controller) (middleware.Middleware, error) {
}
}
- return func(next middleware.Handler) middleware.Handler {
- return secondary.Secondary{file.File{Next: next, Zones: zones}}
- }, nil
+ dnsserver.GetConfig(c).AddMiddleware(func(next dnsserver.Handler) dnsserver.Handler {
+ return Secondary{file.File{Next: next, Zones: zones}}
+ })
+ return nil
}
-func secondaryParse(c *Controller) (file.Zones, error) {
+func secondaryParse(c *caddy.Controller) (file.Zones, error) {
z := make(map[string]*file.Zone)
names := []string{}
+ origins := []string{}
for c.Next() {
if c.Val() == "secondary" {
// secondary [origin]
- origins := c.ServerBlockHosts
+ origins = make([]string, len(c.ServerBlockKeys))
+ copy(origins, c.ServerBlockKeys)
args := c.RemainingArgs()
if len(args) > 0 {
origins = args
@@ -52,7 +63,7 @@ func secondaryParse(c *Controller) (file.Zones, error) {
}
for c.NextBlock() {
- t, f, e := transferParse(c)
+ t, f, e := file.TransferParse(c)
if e != nil {
return file.Zones{}, e
}
diff --git a/plugin_generate.go b/plugin_generate.go
new file mode 100644
index 000000000..fe8fef75d
--- /dev/null
+++ b/plugin_generate.go
@@ -0,0 +1,81 @@
+//+build ignore
+
+package main
+
+import (
+ "bytes"
+ "errors"
+ "go/ast"
+ "go/parser"
+ "go/printer"
+ "go/token"
+ "io/ioutil"
+ "log"
+ "strconv"
+)
+
+func AddImportToFile(file, imprt string) ([]byte, error) {
+ fset := token.NewFileSet()
+ f, err := parser.ParseFile(fset, file, nil, parser.ParseComments)
+ if err != nil {
+ return nil, err
+ }
+
+ for _, s := range f.Imports {
+ iSpec := &ast.ImportSpec{Path: &ast.BasicLit{Value: s.Path.Value}}
+ if iSpec.Path.Value == strconv.Quote(imprt) {
+ return nil, errors.New("coredns import already found")
+ }
+ }
+
+ for i := 0; i < len(f.Decls); i++ {
+ d := f.Decls[i]
+
+ switch d.(type) {
+ case *ast.FuncDecl:
+ // No action
+ case *ast.GenDecl:
+ dd := d.(*ast.GenDecl)
+
+ // IMPORT Declarations
+ if dd.Tok == token.IMPORT {
+ // Add the new import
+ iSpec := &ast.ImportSpec{Name: &ast.Ident{Name: "_"}, Path: &ast.BasicLit{Value: strconv.Quote(imprt)}}
+ dd.Specs = append(dd.Specs, iSpec)
+ break
+ }
+ }
+ }
+
+ ast.SortImports(fset, f)
+
+ out, err := GenerateFile(fset, f)
+ return out, err
+}
+
+func GenerateFile(fset *token.FileSet, file *ast.File) ([]byte, error) {
+ var output []byte
+ buffer := bytes.NewBuffer(output)
+ if err := printer.Fprint(buffer, fset, file); err != nil {
+ return nil, err
+ }
+
+ return buffer.Bytes(), nil
+}
+
+const (
+ coredns = "github.com/miekg/coredns/core"
+ // If everything is OK and we are sitting in CoreDNS' dir, this is where run.go should be.
+ caddyrun = "../../mholt/caddy/caddy/caddymain/run.go"
+)
+
+func main() {
+ out, err := AddImportToFile(caddyrun, coredns)
+ if err != nil {
+ log.Printf("failed to add import: %s", err)
+ return
+ }
+ if err := ioutil.WriteFile(caddyrun, out, 0644); err != nil {
+ log.Fatalf("failed to write go file: %s", err)
+ }
+}
diff --git a/server/server.go b/server/server.go
index 0ba6f7f01..b0f89468e 100644
--- a/server/server.go
+++ b/server/server.go
@@ -28,13 +28,10 @@ import (
// the same address and the listener may be stopped for
// graceful termination (POSIX only).
type Server struct {
- Addr string // Address we listen on
- mux *dns.ServeMux
- server [2]*dns.Server // by convention 0 is tcp and 1 is udp
-
- tcp net.Listener
- udp net.PacketConn
- listenerMu sync.Mutex // protects listener and packetconn
+ Addr string // Address we listen on
+ mux *dns.ServeMux
+ server [2]*dns.Server // by convention 0 is tcp and 1 is udp
+ listenerMu sync.Mutex // protects listener and packetconn inside server
tls bool // whether this server is serving all HTTPS hosts or not
TLSConfig *tls.Config
@@ -115,16 +112,6 @@ func New(addr string, configs []Config, gracefulTimeout time.Duration) (*Server,
return s, nil
}
-// LocalAddr return the addresses where the server is bound to. The TCP listener
-// address is the first returned, the UDP conn address the second.
-func (s *Server) LocalAddr() (net.Addr, net.Addr) {
- s.listenerMu.Lock()
- tcp := s.tcp.Addr()
- udp := s.udp.LocalAddr()
- s.listenerMu.Unlock()
- return tcp, udp
-}
-
// Serve starts the server with an existing listener. It blocks until the server stops.
func (s *Server) Serve(ln net.Listener, pc net.PacketConn) error {
err := s.setup()
@@ -134,9 +121,7 @@ func (s *Server) Serve(ln net.Listener, pc net.PacketConn) error {
}
s.listenerMu.Lock()
s.server[0] = &dns.Server{Listener: ln, Net: "tcp", Handler: s.mux}
- s.tcp = ln
s.server[1] = &dns.Server{PacketConn: pc, Net: "udp", Handler: s.mux}
- s.udp = pc
s.listenerMu.Unlock()
go func() {
@@ -168,9 +153,7 @@ func (s *Server) ListenAndServe() error {
s.listenerMu.Lock()
s.server[0] = &dns.Server{Listener: l, Net: "tcp", Handler: s.mux}
- s.tcp = l
s.server[1] = &dns.Server{PacketConn: pc, Net: "udp", Handler: s.mux}
- s.udp = pc
s.listenerMu.Unlock()
go func() {
@@ -252,17 +235,17 @@ func (s *Server) Stop() (err error) {
// Close the listener now; this stops the server without delay
s.listenerMu.Lock()
- if s.tcp != nil {
- err = s.tcp.Close()
- }
- if s.udp != nil {
- err = s.udp.Close()
- }
+ defer s.listenerMu.Unlock()
for _, s1 := range s.server {
+ if s1.Listener != nil {
+ err = s1.Listener.Close()
+ }
+ if s1.PacketConn != nil {
+ err = s1.PacketConn.Close()
+ }
err = s1.Shutdown()
}
- s.listenerMu.Unlock()
return
}
@@ -280,8 +263,8 @@ func (s *Server) WaitUntilStarted() {
func (s *Server) ListenerFd() *os.File {
s.listenerMu.Lock()
defer s.listenerMu.Unlock()
- if s.tcp != nil {
- file, _ := s.tcp.(*net.TCPListener).File()
+ if s.server[0].Listener != nil {
+ file, _ := s.server[0].Listener.(*net.TCPListener).File()
return file
}
return nil
@@ -293,8 +276,8 @@ func (s *Server) ListenerFd() *os.File {
func (s *Server) PacketConnFd() *os.File {
s.listenerMu.Lock()
defer s.listenerMu.Unlock()
- if s.udp != nil {
- file, _ := s.udp.(*net.UDPConn).File()
+ if s.server[1].PacketConn != nil {
+ file, _ := s.server[1].PacketConn.(*net.UDPConn).File()
return file
}
return nil
diff --git a/test/etcd_test.go b/test/etcd_test.go
index ceec936e7..6beb4cff2 100644
--- a/test/etcd_test.go
+++ b/test/etcd_test.go
@@ -46,13 +46,18 @@ func TestEtcdStubAndProxyLookup(t *testing.T) {
proxy . 8.8.8.8:53
}`
- etc := etcdMiddleware()
- ex, _, udp, err := Server(t, corefile)
+ ex, err := CoreDNSServer(corefile)
if err != nil {
- t.Fatalf("Could get server: %s", err)
+ t.Fatalf("could not get CoreDNS serving instance: %s", err)
+ }
+
+ udp, _ := CoreDNSServerPorts(ex, 0)
+ if udp == "" {
+ t.Fatalf("could not get udp listening port")
}
defer ex.Stop()
+ etc := etcdMiddleware()
log.SetOutput(ioutil.Discard)
var ctx = context.TODO()
diff --git a/test/fail_start_test.go b/test/fail_start_test.go
deleted file mode 100644
index aa1b137af..000000000
--- a/test/fail_start_test.go
+++ /dev/null
@@ -1,21 +0,0 @@
-package test
-
-import (
- "testing"
-
- "github.com/miekg/coredns/core"
-)
-
-// Bind to low port should fail.
-func TestFailStartServer(t *testing.T) {
- corefile := `.:53 {
- chaos CoreDNS-001 miek@miek.nl
-}
-`
- srv, _ := core.TestServer(t, corefile)
- err := srv.ListenAndServe()
- if err == nil {
- srv.Stop()
- t.Fatalf("Low port startup should fail")
- }
-}
diff --git a/test/file.go b/test/file.go
new file mode 100644
index 000000000..b6068a32b
--- /dev/null
+++ b/test/file.go
@@ -0,0 +1,20 @@
+package test
+
+import (
+ "io/ioutil"
+ "os"
+ "testing"
+)
+
+// TempFile will create a temporary file on disk and returns the name and a cleanup function to remove it later.
+func TempFile(t *testing.T, dir, content string) (string, func(), error) {
+ f, err := ioutil.TempFile(dir, "go-test-tmpfile")
+ if err != nil {
+ return "", nil, err
+ }
+ if err := ioutil.WriteFile(f.Name(), []byte(content), 0644); err != nil {
+ return "", nil, err
+ }
+ rmFunc := func() { os.Remove(f.Name()) }
+ return f.Name(), rmFunc, nil
+}
diff --git a/test/file_test.go b/test/file_test.go
new file mode 100644
index 000000000..950ea7cff
--- /dev/null
+++ b/test/file_test.go
@@ -0,0 +1,11 @@
+package test
+
+import "testing"
+
+func TestTempFile(t *testing.T) {
+ _, f, e := TempFile(t, ".", "test")
+ if e != nil {
+ t.Fatalf("failed to create temp file: %s", e)
+ }
+ defer f()
+}
diff --git a/test/helpers.go b/test/helpers.go
new file mode 100644
index 000000000..01a6f156b
--- /dev/null
+++ b/test/helpers.go
@@ -0,0 +1,250 @@
+package test
+
+import (
+ "testing"
+
+ "github.com/miekg/dns"
+ "golang.org/x/net/context"
+)
+
+type Sect int
+
+const (
+ Answer Sect = iota
+ Ns
+ Extra
+)
+
+type RRSet []dns.RR
+
+func (p RRSet) Len() int { return len(p) }
+func (p RRSet) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
+func (p RRSet) Less(i, j int) bool { return p[i].String() < p[j].String() }
+
+// If the TTL of a record is 303 we don't care what the TTL is.
+type Case struct {
+ Qname string
+ Qtype uint16
+ Rcode int
+ Do bool
+ Answer []dns.RR
+ Ns []dns.RR
+ Extra []dns.RR
+}
+
+func (c Case) Msg() *dns.Msg {
+ m := new(dns.Msg)
+ m.SetQuestion(dns.Fqdn(c.Qname), c.Qtype)
+ if c.Do {
+ o := new(dns.OPT)
+ o.Hdr.Name = "."
+ o.Hdr.Rrtype = dns.TypeOPT
+ o.SetDo()
+ o.SetUDPSize(4096)
+ m.Extra = []dns.RR{o}
+ }
+ return m
+}
+
+func A(rr string) *dns.A { r, _ := dns.NewRR(rr); return r.(*dns.A) }
+func AAAA(rr string) *dns.AAAA { r, _ := dns.NewRR(rr); return r.(*dns.AAAA) }
+func CNAME(rr string) *dns.CNAME { r, _ := dns.NewRR(rr); return r.(*dns.CNAME) }
+func SRV(rr string) *dns.SRV { r, _ := dns.NewRR(rr); return r.(*dns.SRV) }
+func SOA(rr string) *dns.SOA { r, _ := dns.NewRR(rr); return r.(*dns.SOA) }
+func NS(rr string) *dns.NS { r, _ := dns.NewRR(rr); return r.(*dns.NS) }
+func PTR(rr string) *dns.PTR { r, _ := dns.NewRR(rr); return r.(*dns.PTR) }
+func TXT(rr string) *dns.TXT { r, _ := dns.NewRR(rr); return r.(*dns.TXT) }
+func MX(rr string) *dns.MX { r, _ := dns.NewRR(rr); return r.(*dns.MX) }
+func RRSIG(rr string) *dns.RRSIG { r, _ := dns.NewRR(rr); return r.(*dns.RRSIG) }
+func NSEC(rr string) *dns.NSEC { r, _ := dns.NewRR(rr); return r.(*dns.NSEC) }
+func DNSKEY(rr string) *dns.DNSKEY { r, _ := dns.NewRR(rr); return r.(*dns.DNSKEY) }
+
+func OPT(bufsize int, do bool) *dns.OPT {
+ o := new(dns.OPT)
+ o.Hdr.Name = "."
+ o.Hdr.Rrtype = dns.TypeOPT
+ o.SetVersion(0)
+ o.SetUDPSize(uint16(bufsize))
+ if do {
+ o.SetDo()
+ }
+ return o
+}
+
+func Header(t *testing.T, tc Case, resp *dns.Msg) bool {
+ if resp.Rcode != tc.Rcode {
+ t.Errorf("rcode is %q, expected %q", dns.RcodeToString[resp.Rcode], dns.RcodeToString[tc.Rcode])
+ return false
+ }
+
+ if len(resp.Answer) != len(tc.Answer) {
+ t.Errorf("answer for %q contained %d results, %d expected", tc.Qname, len(resp.Answer), len(tc.Answer))
+ return false
+ }
+ if len(resp.Ns) != len(tc.Ns) {
+ t.Errorf("authority for %q contained %d results, %d expected", tc.Qname, len(resp.Ns), len(tc.Ns))
+ return false
+ }
+ if len(resp.Extra) != len(tc.Extra) {
+ t.Errorf("additional for %q contained %d results, %d expected", tc.Qname, len(resp.Extra), len(tc.Extra))
+ return false
+ }
+ return true
+}
+
+func Section(t *testing.T, tc Case, sect Sect, rr []dns.RR) bool {
+ section := []dns.RR{}
+ switch sect {
+ case 0:
+ section = tc.Answer
+ case 1:
+ section = tc.Ns
+ case 2:
+ section = tc.Extra
+ }
+
+ for i, a := range rr {
+ if a.Header().Name != section[i].Header().Name {
+ t.Errorf("rr %d should have a Header Name of %q, but has %q", i, section[i].Header().Name, a.Header().Name)
+ return false
+ }
+ // 303 signals: don't care what the ttl is.
+ if section[i].Header().Ttl != 303 && a.Header().Ttl != section[i].Header().Ttl {
+ if _, ok := section[i].(*dns.OPT); !ok {
+ // we check edns0 bufize on this one
+ t.Errorf("rr %d should have a Header TTL of %d, but has %d", i, section[i].Header().Ttl, a.Header().Ttl)
+ return false
+ }
+ }
+ if a.Header().Rrtype != section[i].Header().Rrtype {
+ t.Errorf("rr %d should have a header rr type of %d, but has %d", i, section[i].Header().Rrtype, a.Header().Rrtype)
+ return false
+ }
+
+ switch x := a.(type) {
+ case *dns.SRV:
+ if x.Priority != section[i].(*dns.SRV).Priority {
+ t.Errorf("rr %d should have a Priority of %d, but has %d", i, section[i].(*dns.SRV).Priority, x.Priority)
+ return false
+ }
+ if x.Weight != section[i].(*dns.SRV).Weight {
+ t.Errorf("rr %d should have a Weight of %d, but has %d", i, section[i].(*dns.SRV).Weight, x.Weight)
+ return false
+ }
+ if x.Port != section[i].(*dns.SRV).Port {
+ t.Errorf("rr %d should have a Port of %d, but has %d", i, section[i].(*dns.SRV).Port, x.Port)
+ return false
+ }
+ if x.Target != section[i].(*dns.SRV).Target {
+ t.Errorf("rr %d should have a Target of %q, but has %q", i, section[i].(*dns.SRV).Target, x.Target)
+ return false
+ }
+ case *dns.RRSIG:
+ if x.TypeCovered != section[i].(*dns.RRSIG).TypeCovered {
+ t.Errorf("rr %d should have a TypeCovered of %d, but has %d", i, section[i].(*dns.RRSIG).TypeCovered, x.TypeCovered)
+ return false
+ }
+ if x.Labels != section[i].(*dns.RRSIG).Labels {
+ t.Errorf("rr %d should have a Labels of %d, but has %d", i, section[i].(*dns.RRSIG).Labels, x.Labels)
+ return false
+ }
+ if x.SignerName != section[i].(*dns.RRSIG).SignerName {
+ t.Errorf("rr %d should have a SignerName of %d, but has %d", i, section[i].(*dns.RRSIG).SignerName, x.SignerName)
+ return false
+ }
+ case *dns.NSEC:
+ if x.NextDomain != section[i].(*dns.NSEC).NextDomain {
+ t.Errorf("rr %d should have a NextDomain of %d, but has %d", i, section[i].(*dns.NSEC).NextDomain, x.NextDomain)
+ return false
+ }
+ // TypeBitMap
+ case *dns.A:
+ if x.A.String() != section[i].(*dns.A).A.String() {
+ t.Errorf("rr %d should have a Address of %q, but has %q", i, section[i].(*dns.A).A.String(), x.A.String())
+ return false
+ }
+ case *dns.AAAA:
+ if x.AAAA.String() != section[i].(*dns.AAAA).AAAA.String() {
+ t.Errorf("rr %d should have a Address of %q, but has %q", i, section[i].(*dns.AAAA).AAAA.String(), x.AAAA.String())
+ return false
+ }
+ case *dns.TXT:
+ for j, txt := range x.Txt {
+ if txt != section[i].(*dns.TXT).Txt[j] {
+ t.Errorf("rr %d should have a Txt of %q, but has %q", i, section[i].(*dns.TXT).Txt[j], txt)
+ return false
+ }
+ }
+ case *dns.SOA:
+ tt := section[i].(*dns.SOA)
+ if x.Ns != tt.Ns {
+ t.Errorf("SOA nameserver should be %q, but is %q", x.Ns, tt.Ns)
+ return false
+ }
+ case *dns.PTR:
+ tt := section[i].(*dns.PTR)
+ if x.Ptr != tt.Ptr {
+ t.Errorf("PTR ptr should be %q, but is %q", x.Ptr, tt.Ptr)
+ return false
+ }
+ case *dns.CNAME:
+ tt := section[i].(*dns.CNAME)
+ if x.Target != tt.Target {
+ t.Errorf("CNAME target should be %q, but is %q", x.Target, tt.Target)
+ return false
+ }
+ case *dns.MX:
+ tt := section[i].(*dns.MX)
+ if x.Mx != tt.Mx {
+ t.Errorf("MX Mx should be %q, but is %q", x.Mx, tt.Mx)
+ return false
+ }
+ if x.Preference != tt.Preference {
+ t.Errorf("MX Preference should be %q, but is %q", x.Preference, tt.Preference)
+ return false
+ }
+ case *dns.NS:
+ tt := section[i].(*dns.NS)
+ if x.Ns != tt.Ns {
+ t.Errorf("NS nameserver should be %q, but is %q", x.Ns, tt.Ns)
+ return false
+ }
+ case *dns.OPT:
+ tt := section[i].(*dns.OPT)
+ if x.UDPSize() != tt.UDPSize() {
+ t.Errorf("OPT UDPSize should be %d, but is %d", tt.UDPSize(), x.UDPSize())
+ return false
+ }
+ if x.Do() != tt.Do() {
+ t.Errorf("OPT DO should be %t, but is %t", tt.Do(), x.Do())
+ return false
+ }
+ }
+ }
+ return true
+}
+
+func ErrorHandler() Handler {
+ return HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
+ m := new(dns.Msg)
+ m.SetRcode(r, dns.RcodeServerFailure)
+ w.WriteMsg(m)
+ return dns.RcodeServerFailure, nil
+ })
+}
+
+// Copied here to prevent an import cycle.
+type (
+ // HandlerFunc is a convenience type like dns.HandlerFunc, except
+ // ServeDNS returns an rcode and an error.
+ HandlerFunc func(context.Context, dns.ResponseWriter, *dns.Msg) (int, error)
+
+ Handler interface {
+ ServeDNS(context.Context, dns.ResponseWriter, *dns.Msg) (int, error)
+ }
+)
+
+// ServeDNS implements the Handler interface.
+func (f HandlerFunc) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
+ return f(ctx, w, r)
+}
diff --git a/test/helpers_test.go b/test/helpers_test.go
new file mode 100644
index 000000000..2aa1fe106
--- /dev/null
+++ b/test/helpers_test.go
@@ -0,0 +1,5 @@
+package test
+
+import "testing"
+
+func TestA(t *testing.T) { A("miek.nl. IN A 127.0.0.1") } // should not crash
diff --git a/test/kubernetes_test.go b/test/kubernetes_test.go
index 55939e350..009b50657 100644
--- a/test/kubernetes_test.go
+++ b/test/kubernetes_test.go
@@ -3,12 +3,12 @@
package test
import (
- "fmt"
"io/ioutil"
"log"
"testing"
"github.com/miekg/coredns/middleware/kubernetes/k8stest"
+
"github.com/miekg/dns"
)
@@ -64,9 +64,8 @@ var testdataLookupSRV = []struct {
}
func TestK8sIntegration(t *testing.T) {
- t.Log(" === RUN testLookupA")
+ // subtests here (Go 1.7 feature).
testLookupA(t)
- t.Log(" === RUN testLookupSRV")
testLookupSRV(t)
}
@@ -75,7 +74,7 @@ func testLookupA(t *testing.T) {
t.Skip("Skipping Kubernetes Integration tests. Kubernetes is not running")
}
- coreFile :=
+ corefile :=
`.:0 {
kubernetes coredns.local {
endpoint http://localhost:8080
@@ -83,16 +82,20 @@ func testLookupA(t *testing.T) {
}
`
- server, _, udp, err := Server(t, coreFile)
+ server, err := CoreDNSServer(corefile)
if err != nil {
- t.Fatal("Could not get server: %s", err)
+ t.Fatalf("could not get CoreDNS serving instance: %s", err)
+ }
+
+ udp, _ := CoreDNSServerPorts(server, 0)
+ if udp == "" {
+ t.Fatalf("could not get udp listening port")
}
defer server.Stop()
log.SetOutput(ioutil.Discard)
for _, testData := range testdataLookupA {
- t.Logf("[log] Testing query string: '%v'\n", testData.Query)
dnsClient := new(dns.Client)
dnsMessage := new(dns.Msg)
@@ -125,7 +128,7 @@ func testLookupSRV(t *testing.T) {
t.Skip("Skipping Kubernetes Integration tests. Kubernetes is not running")
}
- coreFile :=
+ corefile :=
`.:0 {
kubernetes coredns.local {
endpoint http://localhost:8080
@@ -133,9 +136,13 @@ func testLookupSRV(t *testing.T) {
}
`
- server, _, udp, err := Server(t, coreFile)
+ server, err := CoreDNSServer(corefile)
if err != nil {
- t.Fatal("Could not get server: %s", err)
+ t.Fatalf("could not get CoreDNS serving instance: %s", err)
+ }
+ udp, _ := CoreDNSServerPorts(server, 0)
+ if udp == "" {
+ t.Fatalf("could not get udp listening port")
}
defer server.Stop()
@@ -144,7 +151,6 @@ func testLookupSRV(t *testing.T) {
// TODO: Add checks for A records in additional section
for _, testData := range testdataLookupSRV {
- t.Logf("[log] Testing query string: '%v'\n", testData.Query)
dnsClient := new(dns.Client)
dnsMessage := new(dns.Msg)
@@ -158,7 +164,6 @@ func testLookupSRV(t *testing.T) {
// Count SRV records in the answer section
srvRecordCount := 0
for _, a := range res.Answer {
- fmt.Printf("RR: %v\n", a)
if a.Header().Rrtype == dns.TypeSRV {
srvRecordCount++
}
diff --git a/test/middleware_dnssec_test.go b/test/middleware_dnssec_test.go
index 434e7ad56..afde72a54 100644
--- a/test/middleware_dnssec_test.go
+++ b/test/middleware_dnssec_test.go
@@ -29,11 +29,12 @@ func TestLookupBalanceRewriteCacheDnssec(t *testing.T) {
loadbalance
}
`
- ex, _, udp, err := Server(t, corefile)
+ ex, err := CoreDNSServer(corefile)
if err != nil {
- t.Errorf("Could get server to start: %s", err)
- return
+ t.Fatalf("could not get CoreDNS serving instance: %s", err)
}
+
+ udp, _ := CoreDNSServerPorts(ex, 0)
defer ex.Stop()
log.SetOutput(ioutil.Discard)
diff --git a/test/middleware_test.go b/test/middleware_test.go
index ec90f0d71..f7fefe78a 100644
--- a/test/middleware_test.go
+++ b/test/middleware_test.go
@@ -10,7 +10,7 @@ import (
"github.com/miekg/dns"
)
-func BenchmarkLookupBalanceRewriteCache(b *testing.B) {
+func benchmarkLookupBalanceRewriteCache(b *testing.B) {
t := new(testing.T)
name, rm, err := test.TempFile(t, ".", exampleOrg)
if err != nil {
@@ -24,10 +24,12 @@ func BenchmarkLookupBalanceRewriteCache(b *testing.B) {
loadbalance
}
`
- ex, _, udp, err := Server(t, corefile)
+
+ ex, err := CoreDNSServer(corefile)
if err != nil {
- t.Fatalf("Could get server: %s", err)
+ t.Fatalf("could not get CoreDNS serving instance: %s", err)
}
+ udp, _ := CoreDNSServerPorts(ex, 0)
defer ex.Stop()
log.SetOutput(ioutil.Discard)
diff --git a/test/proxy_test.go b/test/proxy_test.go
index 56ef159fb..ca04e1ae8 100644
--- a/test/proxy_test.go
+++ b/test/proxy_test.go
@@ -28,14 +28,20 @@ func TestLookupProxy(t *testing.T) {
defer rm()
corefile := `example.org:0 {
- file ` + name + `
+ file ` + name + `
}
`
- ex, _, udp, err := Server(t, corefile)
+
+ i, err := CoreDNSServer(corefile)
if err != nil {
- t.Fatalf("Could get server: %s", err)
+ t.Fatalf("could not get CoreDNS serving instance: %s", err)
}
- defer ex.Stop()
+
+ udp, _ := CoreDNSServerPorts(i, 0)
+ if udp == "" {
+ t.Fatalf("could not get udp listening port")
+ }
+ defer i.Stop()
log.SetOutput(ioutil.Discard)
@@ -43,8 +49,7 @@ func TestLookupProxy(t *testing.T) {
state := middleware.State{W: &test.ResponseWriter{}, Req: new(dns.Msg)}
resp, err := p.Lookup(state, "example.org.", dns.TypeA)
if err != nil {
- t.Error("Expected to receive reply, but didn't")
- return
+ t.Fatal("Expected to receive reply, but didn't")
}
// expect answer section with A record in it
if len(resp.Answer) == 0 {
diff --git a/test/responsewriter.go b/test/responsewriter.go
new file mode 100644
index 000000000..fb70d7e8d
--- /dev/null
+++ b/test/responsewriter.go
@@ -0,0 +1,28 @@
+package test
+
+import (
+ "net"
+
+ "github.com/miekg/dns"
+)
+
+type ResponseWriter struct{}
+
+func (t *ResponseWriter) LocalAddr() net.Addr {
+ ip := net.ParseIP("127.0.0.1")
+ port := 53
+ return &net.UDPAddr{IP: ip, Port: port, Zone: ""}
+}
+
+func (t *ResponseWriter) RemoteAddr() net.Addr {
+ ip := net.ParseIP("10.240.0.1")
+ port := 40212
+ return &net.UDPAddr{IP: ip, Port: port, Zone: ""}
+}
+
+func (t *ResponseWriter) WriteMsg(m *dns.Msg) error { return nil }
+func (t *ResponseWriter) Write(buf []byte) (int, error) { return len(buf), nil }
+func (t *ResponseWriter) Close() error { return nil }
+func (t *ResponseWriter) TsigStatus() error { return nil }
+func (t *ResponseWriter) TsigTimersOnly(bool) { return }
+func (t *ResponseWriter) Hijack() { return }
diff --git a/test/server.go b/test/server.go
new file mode 100644
index 000000000..324ffdd8f
--- /dev/null
+++ b/test/server.go
@@ -0,0 +1,91 @@
+package test
+
+import (
+ "net"
+ "sync"
+ "testing"
+ "time"
+
+ _ "github.com/miekg/coredns/core"
+
+ "github.com/mholt/caddy"
+ "github.com/miekg/dns"
+)
+
+func TCPServer(t *testing.T, laddr string) (*dns.Server, string, error) {
+ l, err := net.Listen("tcp", laddr)
+ if err != nil {
+ return nil, "", err
+ }
+
+ server := &dns.Server{Listener: l, ReadTimeout: time.Hour, WriteTimeout: time.Hour}
+
+ waitLock := sync.Mutex{}
+ waitLock.Lock()
+ server.NotifyStartedFunc = func() { t.Logf("started TCP server on %s", l.Addr()); waitLock.Unlock() }
+
+ go func() {
+ server.ActivateAndServe()
+ l.Close()
+ }()
+
+ waitLock.Lock()
+ return server, l.Addr().String(), nil
+}
+
+func UDPServer(t *testing.T, laddr string) (*dns.Server, string, error) {
+ pc, err := net.ListenPacket("udp", laddr)
+ if err != nil {
+ return nil, "", err
+ }
+ server := &dns.Server{PacketConn: pc, ReadTimeout: time.Hour, WriteTimeout: time.Hour}
+
+ waitLock := sync.Mutex{}
+ waitLock.Lock()
+ server.NotifyStartedFunc = func() { t.Logf("started UDP server on %s", pc.LocalAddr()); waitLock.Unlock() }
+
+ go func() {
+ server.ActivateAndServe()
+ pc.Close()
+ }()
+
+ waitLock.Lock()
+ return server, pc.LocalAddr().String(), nil
+}
+
+// CoreDNSServer returns a test server. It just takes a normal Corefile as input.
+func CoreDNSServer(corefile string) (*caddy.Instance, error) { return caddy.Start(NewInput(corefile)) }
+
+// CoreDNSSserverStop stops a server.
+func CoreDNSServerStop(i *caddy.Instance) { i.Stop() }
+
+// CoreDNSServeRPorts returns the ports the instance is listening on. The integer k indicates
+// which ServerListener you want.
+func CoreDNSServerPorts(i *caddy.Instance, k int) (udp, tcp string) {
+ srvs := i.Servers()
+ if len(srvs) < k+1 {
+ return "", ""
+ }
+ u := srvs[k].LocalAddr()
+ t := srvs[k].Addr()
+
+ if u != nil {
+ udp = u.String()
+ }
+ if t != nil {
+ tcp = t.String()
+ }
+ return
+}
+
+type Input struct {
+ corefile []byte
+}
+
+func NewInput(corefile string) *Input {
+ return &Input{corefile: []byte(corefile)}
+}
+
+func (i *Input) Body() []byte { return i.corefile }
+func (i *Input) Path() string { return "Corefile" }
+func (i *Input) ServerType() string { return "dns" }
diff --git a/test/server_test.go b/test/server_test.go
index 6a86d022b..a03285bf3 100644
--- a/test/server_test.go
+++ b/test/server_test.go
@@ -12,24 +12,28 @@ func TestProxyToChaosServer(t *testing.T) {
chaos CoreDNS-001 miek@miek.nl
}
`
- chaos, tcpCH, udpCH, err := Server(t, corefile)
+ chaos, err := CoreDNSServer(corefile)
if err != nil {
- t.Fatalf("Could get server: %s", err)
+ t.Fatalf("could not get CoreDNS serving instance: %s", err)
}
+
+ udpChaos, tcpChaos := CoreDNSServerPorts(chaos, 0)
defer chaos.Stop()
corefileProxy := `.:0 {
- proxy . ` + udpCH + `
+ proxy . ` + udpChaos + `
}
`
- proxy, _, udp, err := Server(t, corefileProxy)
+ proxy, err := CoreDNSServer(corefileProxy)
if err != nil {
- t.Fatalf("Could get server: %s", err)
+ t.Fatalf("could not get CoreDNS serving instance")
}
+
+ udp, _ := CoreDNSServerPorts(proxy, 0)
defer proxy.Stop()
- chaosTest(t, udpCH, "udp")
- chaosTest(t, tcpCH, "tcp")
+ chaosTest(t, udpChaos, "udp")
+ chaosTest(t, tcpChaos, "tcp")
chaosTest(t, udp, "udp")
// chaosTest(t, tcp, "tcp"), commented out because we use the original transport to reach the
diff --git a/test/tests.go b/test/tests.go
index d38bf955f..0f9d12bae 100644
--- a/test/tests.go
+++ b/test/tests.go
@@ -1,12 +1,7 @@
package test
import (
- "testing"
- "time"
-
- "github.com/miekg/coredns/core"
"github.com/miekg/coredns/middleware"
- "github.com/miekg/coredns/server"
"github.com/miekg/dns"
)
@@ -25,16 +20,3 @@ func Exchange(m *dns.Msg, server, net string) (*dns.Msg, error) {
c.Net = net
return middleware.Exchange(c, m, server)
}
-
-// Server returns a test server and the tcp and udp listeners addresses.
-func Server(t *testing.T, corefile string) (*server.Server, string, string, error) {
- srv, err := core.TestServer(t, corefile)
- if err != nil {
- return nil, "", "", err
- }
- go srv.ListenAndServe()
-
- time.Sleep(1 * time.Second) // yeah... I regret nothing
- tcp, udp := srv.LocalAddr()
- return srv, tcp.String(), udp.String(), nil
-}