diff --git a/.travis.yml b/.travis.yml index bcd0ff760..d10764882 100644 --- a/.travis.yml +++ b/.travis.yml @@ -12,7 +12,7 @@ go: # In the Travis VM-based build environment, IPv6 networking is not # enabled by default. The sysctl operations below enable IPv6. # IPv6 is needed by some of the CoreDNS test cases. The VM environment -# is needed to have access to sudo in the test environment. Sudo is +# is needed to have access to sudo in the test environment. Sudo is # needed to have docker in the test environment. Docker is needed to # launch a kubernetes instance in the test environment. # (Dependencies are fun! :) ) @@ -38,9 +38,9 @@ before_script: - if which docker &>/dev/null ; then docker pull gcr.io/google_containers/hyperkube-amd64:v1.2.4 ; docker ps -a ; fi - if which docker &>/dev/null ; then ./contrib/kubernetes/testscripts/start_k8s_with_services.sh ; docker ps -a ; fi # Get golang dependencies, and build coredns binary - - go get -v -d + - go get -v -d ./... - go get github.com/coreos/go-etcd/etcd - - go build -v -ldflags="-s -w" + #- go build -v -ldflags="-s -w" script: - go test -tags etcd -race -bench=. ./... diff --git a/Makefile b/Makefile index d8cf42b45..d455840a6 100644 --- a/Makefile +++ b/Makefile @@ -1,39 +1,63 @@ #BUILD_VERBOSE := BUILD_VERBOSE := -v -TEST_VERBOSE := +#TEST_VERBOSE := TEST_VERBOSE := -v DOCKER_IMAGE_NAME := $$USER/coredns +all: coredns -all: +# Phony this to ensure we always build the binary. +# TODO: Add .go file dependencies. +.PHONY: coredns +coredns: generate deps go build $(BUILD_VERBOSE) -ldflags="-s -w" .PHONY: docker -docker: all - GOOS=linux go build -a -tags netgo -installsuffix netgo -ldflags="-s -w" +docker: deps + CGO_ENABLED=0 GOOS=linux go build -ldflags="-s -w" docker build -t $(DOCKER_IMAGE_NAME) . +../../mholt/caddy: + # Get caddy so we can generate into that codebase + # before getting all other dependencies. + go get ${BUILD_VERBOSE} github.com/mholt/caddy + +.PHONY: generate +generate: ../../mholt/caddy + go generate $(BUILD_VERSOSE) + .PHONY: deps -deps: +deps: generate go get ${BUILD_VERBOSE} .PHONY: test -test: +test: deps go test $(TEST_VERBOSE) ./... .PHONY: testk8s -testk8s: +testk8s: deps # With -args --v=100 the k8s API response data will be printed in the log: #go test $(TEST_VERBOSE) -tags=k8s -run 'TestK8sIntegration' ./test -args --v=100 # Without the k8s API response data: go test $(TEST_VERBOSE) -tags=k8s -run 'TestK8sIntegration' ./test .PHONY: testk8s-setup -testk8s-setup: - go test -v ./core/setup -run TestKubernetes +testk8s-setup: deps + go test -v ./middleware/kubernetes/... -run TestKubernetes .PHONY: clean clean: go clean + rm -f coredns + +.PHONY: distclean +distclean: clean + # Clean all dependencies and build artifacts + find $(GOPATH)/pkg -maxdepth 1 -mindepth 1 | xargs rm -rf + find $(GOPATH)/bin -maxdepth 1 -mindepth 1 | xargs rm -rf + + find $(GOPATH)/src -maxdepth 1 -mindepth 1 | grep -v github | xargs rm -rf + find $(GOPATH)/src -maxdepth 2 -mindepth 2 | grep -v miekg | xargs rm -rf + find $(GOPATH)/src/github.com/miekg -maxdepth 1 -mindepth 1 \! -name \*coredns\* | xargs rm -rf diff --git a/README.md b/README.md index e887e0875..47824aeb2 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,8 @@ # CoreDNS CoreDNS is DNS server that started as a fork of [Caddy](https://github.com/mholt/caddy/). It has the -same model: it chains middleware. +same model: it chains middleware. In fact to similar that CoreDNS is now a server type plugin for +CAddy, i.e. you'll need Caddy to compile CoreDNS. CoreDNS is the successor of [SkyDNS](https://github.com/skynetservices/skydns). SkyDNS is a thin layer that exposes services in etcd in the DNS. CoreDNS builds on this idea and is a generic DNS @@ -38,7 +39,7 @@ There are still few [issues](https://github.com/miekg/coredns/issues), and work things fast and reduce the memory usage. All in all, CoreDNS should be able to provide you with enough functionality to replace parts of -BIND9, Knot, NSD or PowerDNS. +BIND9, Knot, NSD or PowerDNS and SkyDNS. Most documentation is in the source and some blog articles can be [found here](https://miek.nl/tags/coredns/). If you do want to use CoreDNS in production, please let us know and how we can help. @@ -46,6 +47,25 @@ know and how we can help. is also full of examples on how to structure a Corefile (renamed from Caddyfile when I forked it). +## Compilation + +CoreDNS (as a servertype plugin for Caddy) has a hard dependency on Caddy - this is *almost* like +the normal Go dependencies, but with a small twist, caddy (the source) need to know that CoreDNS +exists and for this we need to add 1 line `_ "github.com/miekg/coredns/core"` to file in caddy. + +You have the source of CoreDNS, this should preferably be downloaded under your `$GOPATH`. Get all +dependencies: + + go get ./... + +Then, execute `go generate`, this will patch Caddy to add CoreDNS, and then `go build` as you would +normally do: + + go generate + go build + +Should yield a `coredns` binary. + ## Examples Start a simple proxy: @@ -95,15 +115,17 @@ All the above examples are possible with the *current* CoreDNS. ## What remains to be done -* Website? -* Logo? * Optimizations. * Load testing. * The [issues](https://github.com/miekg/coredns/issues). -## Blog +## Blog and Contact + +Website: +Twitter: `@coredns.io` +Docs: +Github: - ## Systemd service file @@ -123,7 +145,7 @@ LimitNOFILE=8192 User=coredns WorkingDirectory=/home/coredns ExecStartPre=/sbin/setcap cap_net_bind_service=+ep /opt/bin/coredns -ExecStart=/opt/bin/coredns -pidfile /home/coredns/coredns.pid -conf=/etc/coredns +ExecStart=/opt/bin/coredns -pidfile /home/coredns/coredns.pid -conf=/etc/coredns/Corefile ExecReload=/bin/kill -SIGUSR1 $MAINPID Restart=on-failure diff --git a/core/assets/path.go b/core/assets/path.go deleted file mode 100644 index abe4638d0..000000000 --- a/core/assets/path.go +++ /dev/null @@ -1,29 +0,0 @@ -package assets - -import ( - "os" - "path/filepath" - "runtime" -) - -// Path returns the path to the folder -// where the application may store data. This -// currently resolves to ~/.coredns -func Path() string { - return filepath.Join(userHomeDir(), ".coredns") -} - -// userHomeDir returns the user's home directory according to -// environment variables. -// -// Credit: http://stackoverflow.com/a/7922977/1048862 -func userHomeDir() string { - if runtime.GOOS == "windows" { - home := os.Getenv("HOMEDRIVE") + os.Getenv("HOMEPATH") - if home == "" { - home = os.Getenv("USERPROFILE") - } - return home - } - return os.Getenv("HOME") -} diff --git a/core/assets/path_test.go b/core/assets/path_test.go deleted file mode 100644 index 6f2c0bfb7..000000000 --- a/core/assets/path_test.go +++ /dev/null @@ -1,12 +0,0 @@ -package assets - -import ( - "strings" - "testing" -) - -func TestPath(t *testing.T) { - if actual := Path(); !strings.HasSuffix(actual, ".coredns") { - t.Errorf("Expected path to be a .coredns folder, got: %v", actual) - } -} diff --git a/core/caddy_test.go b/core/caddy_test.go deleted file mode 100644 index 89ec3d045..000000000 --- a/core/caddy_test.go +++ /dev/null @@ -1,28 +0,0 @@ -package core - -/* -func TestCaddyStartStop(t *testing.T) { - corefile := "localhost:1984" - - for i := 0; i < 2; i++ { - err := Start(CorefileInput{Contents: []byte(corefile)}) - if err != nil { - t.Fatalf("Error starting, iteration %d: %v", i, err) - } - - client := http.Client{ - Timeout: time.Duration(2 * time.Second), - } - resp, err := client.Get("http://localhost:1984") - if err != nil { - t.Fatalf("Expected GET request to succeed (iteration %d), but it failed: %v", i, err) - } - resp.Body.Close() - - err = Stop() - if err != nil { - t.Fatalf("Error stopping, iteration %d: %v", i, err) - } - } -} -*/ diff --git a/core/config.go b/core/config.go deleted file mode 100644 index eea099a70..000000000 --- a/core/config.go +++ /dev/null @@ -1,340 +0,0 @@ -package core - -import ( - "bytes" - "fmt" - "io" - "log" - "net" - "sync" - - "github.com/miekg/coredns/core/parse" - "github.com/miekg/coredns/core/setup" - "github.com/miekg/coredns/server" -) - -const ( - // DefaultConfigFile is the name of the configuration file that is loaded - // by default if no other file is specified. - DefaultConfigFile = "Corefile" -) - -func loadConfigsUpToIncludingTLS(filename string, input io.Reader) ([]server.Config, []parse.ServerBlock, int, error) { - var configs []server.Config - - // Each server block represents similar hosts/addresses, since they - // were grouped together in the Corefile. - serverBlocks, err := parse.ServerBlocks(filename, input, true) - if err != nil { - return nil, nil, 0, err - } - if len(serverBlocks) == 0 { - newInput := DefaultInput() - serverBlocks, err = parse.ServerBlocks(newInput.Path(), bytes.NewReader(newInput.Body()), true) - if err != nil { - return nil, nil, 0, err - } - } - - var lastDirectiveIndex int // we set up directives in two parts; this stores where we left off - - // Iterate each server block and make a config for each one, - // executing the directives that were parsed in order up to the tls - // directive; this is because we must activate Let's Encrypt. - for i, sb := range serverBlocks { - onces := makeOnces() - storages := makeStorages() - - for j, addr := range sb.Addresses { - config := server.Config{ - Host: addr.Host, - Port: addr.Port, - Root: Root, - ConfigFile: filename, - AppName: AppName, - AppVersion: AppVersion, - } - - // It is crucial that directives are executed in the proper order. - for k, dir := range directiveOrder { - // Execute directive if it is in the server block - if tokens, ok := sb.Tokens[dir.name]; ok { - // Each setup function gets a controller, from which setup functions - // get access to the config, tokens, and other state information useful - // to set up its own host only. - controller := &setup.Controller{ - Config: &config, - Dispenser: parse.NewDispenserTokens(filename, tokens), - OncePerServerBlock: func(f func() error) error { - var err error - onces[dir.name].Do(func() { - err = f() - }) - return err - }, - ServerBlockIndex: i, - ServerBlockHostIndex: j, - ServerBlockHosts: sb.HostList(), - ServerBlockStorage: storages[dir.name], - } - // execute setup function and append middleware handler, if any - midware, err := dir.setup(controller) - if err != nil { - return nil, nil, lastDirectiveIndex, err - } - if midware != nil { - config.Middleware = append(config.Middleware, midware) - } - storages[dir.name] = controller.ServerBlockStorage // persist for this server block - } - - // Stop after TLS setup, since we need to activate Let's Encrypt before continuing; - // it makes some changes to the configs that middlewares might want to know about. - if dir.name == "tls" { - lastDirectiveIndex = k - break - } - } - - configs = append(configs, config) - } - } - return configs, serverBlocks, lastDirectiveIndex, nil -} - -// loadConfigs reads input (named filename) and parses it, returning the -// server configurations in the order they appeared in the input. As part -// of this, it activates Let's Encrypt for the configs that are produced. -// Thus, the returned configs are already optimally configured for HTTPS. -func loadConfigs(filename string, input io.Reader) ([]server.Config, error) { - configs, serverBlocks, lastDirectiveIndex, err := loadConfigsUpToIncludingTLS(filename, input) - if err != nil { - return nil, err - } - - // Now we have all the configs, but they have only been set up to the - // point of tls. We need to activate Let's Encrypt before setting up - // the rest of the middlewares so they have correct information regarding - // TLS configuration, if necessary. (this only appends, so our iterations - // over server blocks below shouldn't be affected) - if !IsRestart() && !Quiet { - fmt.Println("Activating privacy features...") - } - /* TODO(miek): stopped for now - configs, err = https.Activate(configs) - if err != nil { - return nil, err - } else if !IsRestart() && !Quiet { - fmt.Println(" done.") - } - */ - - // Finish setting up the rest of the directives, now that TLS is - // optimally configured. These loops are similar to above except - // we don't iterate all the directives from the beginning and we - // don't create new configs. - configIndex := -1 - for i, sb := range serverBlocks { - onces := makeOnces() - storages := makeStorages() - - for j := range sb.Addresses { - configIndex++ - - for k := lastDirectiveIndex + 1; k < len(directiveOrder); k++ { - dir := directiveOrder[k] - - if tokens, ok := sb.Tokens[dir.name]; ok { - controller := &setup.Controller{ - Config: &configs[configIndex], - Dispenser: parse.NewDispenserTokens(filename, tokens), - OncePerServerBlock: func(f func() error) error { - var err error - onces[dir.name].Do(func() { - err = f() - }) - return err - }, - ServerBlockIndex: i, - ServerBlockHostIndex: j, - ServerBlockHosts: sb.HostList(), - ServerBlockStorage: storages[dir.name], - } - midware, err := dir.setup(controller) - if err != nil { - return nil, err - } - if midware != nil { - configs[configIndex].Middleware = append(configs[configIndex].Middleware, midware) - } - storages[dir.name] = controller.ServerBlockStorage // persist for this server block - } - } - } - } - - return configs, nil -} - -// makeOnces makes a map of directive name to sync.Once -// instance. This is intended to be called once per server -// block when setting up configs so that Setup functions -// for each directive can perform a task just once per -// server block, even if there are multiple hosts on the block. -// -// We need one Once per directive, otherwise the first -// directive to use it would exclude other directives from -// using it at all, which would be a bug. -func makeOnces() map[string]*sync.Once { - onces := make(map[string]*sync.Once) - for _, dir := range directiveOrder { - onces[dir.name] = new(sync.Once) - } - return onces -} - -// makeStorages makes a map of directive name to interface{} -// so that directives' setup functions can persist state -// between different hosts on the same server block during the -// setup phase. -func makeStorages() map[string]interface{} { - storages := make(map[string]interface{}) - for _, dir := range directiveOrder { - storages[dir.name] = nil - } - return storages -} - -// arrangeBindings groups configurations by their bind address. For example, -// a server that should listen on localhost and another on 127.0.0.1 will -// be grouped into the same address: 127.0.0.1. It will return an error -// if an address is malformed or a TLS listener is configured on the -// same address as a plaintext HTTP listener. The return value is a map of -// bind address to list of configs that would become VirtualHosts on that -// server. Use the keys of the returned map to create listeners, and use -// the associated values to set up the virtualhosts. -func arrangeBindings(allConfigs []server.Config) (bindingGroup, error) { - var groupings bindingGroup - - // Group configs by bind address - for _, conf := range allConfigs { - // use default port if none is specified - if conf.Port == "" { - conf.Port = Port - } - - bindAddr, warnErr, fatalErr := resolveAddr(conf) - if fatalErr != nil { - return groupings, fatalErr - } - if warnErr != nil { - log.Printf("[WARNING] Resolving bind address for %s: %v", conf.Address(), warnErr) - } - - // Make sure to compare the string representation of the address, - // not the pointer, since a new *TCPAddr is created each time. - var existing bool - for i := 0; i < len(groupings); i++ { - if groupings[i].BindAddr.String() == bindAddr.String() { - groupings[i].Configs = append(groupings[i].Configs, conf) - existing = true - break - } - } - if !existing { - groupings = append(groupings, bindingMapping{ - BindAddr: bindAddr, - Configs: []server.Config{conf}, - }) - } - } - - // Don't allow HTTP and HTTPS to be served on the same address - for _, group := range groupings { - isTLS := group.Configs[0].TLS.Enabled - for _, config := range group.Configs { - if config.TLS.Enabled != isTLS { - thisConfigProto, otherConfigProto := "HTTP", "HTTP" - if config.TLS.Enabled { - thisConfigProto = "HTTPS" - } - if group.Configs[0].TLS.Enabled { - otherConfigProto = "HTTPS" - } - return groupings, fmt.Errorf("configuration error: Cannot multiplex %s (%s) and %s (%s) on same address", - group.Configs[0].Address(), otherConfigProto, config.Address(), thisConfigProto) - } - } - } - - return groupings, nil -} - -// resolveAddr determines the address (host and port) that a config will -// bind to. The returned address, resolvAddr, should be used to bind the -// listener or group the config with other configs using the same address. -// The first error, if not nil, is just a warning and should be reported -// but execution may continue. The second error, if not nil, is a real -// problem and the server should not be started. -// -// This function does not handle edge cases like port "http" or "https" if -// they are not known to the system. It does, however, serve on the wildcard -// host if resolving the address of the specific hostname fails. -func resolveAddr(conf server.Config) (resolvAddr *net.TCPAddr, warnErr, fatalErr error) { - resolvAddr, warnErr = net.ResolveTCPAddr("tcp", net.JoinHostPort(conf.BindHost, conf.Port)) - if warnErr != nil { - // the hostname probably couldn't be resolved, just bind to wildcard then - resolvAddr, fatalErr = net.ResolveTCPAddr("tcp", net.JoinHostPort("", conf.Port)) - if fatalErr != nil { - return - } - } - - return -} - -// validDirective returns true if d is a valid -// directive; false otherwise. -func validDirective(d string) bool { - for _, dir := range directiveOrder { - if dir.name == d { - return true - } - } - return false -} - -// DefaultInput returns the default Corefile input -// to use when it is otherwise empty or missing. -// It uses the default host and port and root. -func DefaultInput() CorefileInput { - port := Port - return CorefileInput{ - Contents: []byte(fmt.Sprintf("%s:%s\nroot %s", Host, port, Root)), - } -} - -// These defaults are configurable through the command line -var ( - // Root is the site root - Root = DefaultRoot - - // Host is the site host - Host = DefaultHost - - // Port is the site port - Port = DefaultPort -) - -// bindingMapping maps a network address to configurations -// that will bind to it. The order of the configs is important. -type bindingMapping struct { - BindAddr *net.TCPAddr - Configs []server.Config -} - -// bindingGroup maps network addresses to their configurations. -// Preserving the order of the groupings is important -// (related to graceful shutdown and restart) -// so this is a slice, not a literal map. -type bindingGroup []bindingMapping diff --git a/core/config_test.go b/core/config_test.go deleted file mode 100644 index c28dd4fcf..000000000 --- a/core/config_test.go +++ /dev/null @@ -1,159 +0,0 @@ -package core - -import ( - "reflect" - "sync" - "testing" - - "github.com/miekg/coredns/server" -) - -func TestDefaultInput(t *testing.T) { - if actual, expected := string(DefaultInput().Body()), ":53\nroot ."; actual != expected { - t.Errorf("Host=%s; Port=%s; Root=%s;\nEXPECTED: '%s'\n ACTUAL: '%s'", Host, Port, Root, expected, actual) - } - - // next few tests simulate user providing -host and/or -port flags - - Host = "not-localhost.com" - if actual, expected := string(DefaultInput().Body()), "not-localhost.com:53\nroot ."; actual != expected { - t.Errorf("Host=%s; Port=%s; Root=%s;\nEXPECTED: '%s'\n ACTUAL: '%s'", Host, Port, Root, expected, actual) - } - - Host = "[::1]" - if actual, expected := string(DefaultInput().Body()), "[::1]:53\nroot ."; actual != expected { - t.Errorf("Host=%s; Port=%s; Root=%s;\nEXPECTED: '%s'\n ACTUAL: '%s'", Host, Port, Root, expected, actual) - } - - Host = "127.0.1.1" - if actual, expected := string(DefaultInput().Body()), "127.0.1.1:53\nroot ."; actual != expected { - t.Errorf("Host=%s; Port=%s; Root=%s;\nEXPECTED: '%s'\n ACTUAL: '%s'", Host, Port, Root, expected, actual) - } - - Host = "not-localhost.com" - Port = "1234" - if actual, expected := string(DefaultInput().Body()), "not-localhost.com:1234\nroot ."; actual != expected { - t.Errorf("Host=%s; Port=%s; Root=%s;\nEXPECTED: '%s'\n ACTUAL: '%s'", Host, Port, Root, expected, actual) - } - - Host = DefaultHost - Port = "1234" - if actual, expected := string(DefaultInput().Body()), ":1234\nroot ."; actual != expected { - t.Errorf("Host=%s; Port=%s; Root=%s;\nEXPECTED: '%s'\n ACTUAL: '%s'", Host, Port, Root, expected, actual) - } -} - -func TestResolveAddr(t *testing.T) { - // NOTE: If tests fail due to comparing to string "127.0.0.1", - // it's possible that system env resolves with IPv6, or ::1. - // If that happens, maybe we should use actualAddr.IP.IsLoopback() - // for the assertion, rather than a direct string comparison. - - // NOTE: Tests with {Host: "", Port: ""} and {Host: "localhost", Port: ""} - // will not behave the same cross-platform, so they have been omitted. - - for i, test := range []struct { - config server.Config - shouldWarnErr bool - shouldFatalErr bool - expectedIP string - expectedPort int - }{ - {server.Config{Host: "127.0.0.1", Port: "1234"}, false, false, "", 1234}, - {server.Config{Host: "localhost", Port: "80"}, false, false, "", 80}, - {server.Config{BindHost: "localhost", Port: "1234"}, false, false, "127.0.0.1", 1234}, - {server.Config{BindHost: "127.0.0.1", Port: "1234"}, false, false, "127.0.0.1", 1234}, - {server.Config{BindHost: "should-not-resolve", Port: "1234"}, true, false, "", 1234}, - {server.Config{BindHost: "localhost", Port: "http"}, false, false, "127.0.0.1", 80}, - {server.Config{BindHost: "localhost", Port: "https"}, false, false, "127.0.0.1", 443}, - {server.Config{BindHost: "", Port: "1234"}, false, false, "", 1234}, - {server.Config{BindHost: "localhost", Port: "abcd"}, false, true, "", 0}, - {server.Config{BindHost: "127.0.0.1", Host: "should-not-be-used", Port: "1234"}, false, false, "127.0.0.1", 1234}, - {server.Config{BindHost: "localhost", Host: "should-not-be-used", Port: "1234"}, false, false, "127.0.0.1", 1234}, - {server.Config{BindHost: "should-not-resolve", Host: "localhost", Port: "1234"}, true, false, "", 1234}, - } { - actualAddr, warnErr, fatalErr := resolveAddr(test.config) - - if test.shouldFatalErr && fatalErr == nil { - t.Errorf("Test %d: Expected error, but there wasn't any", i) - } - if !test.shouldFatalErr && fatalErr != nil { - t.Errorf("Test %d: Expected no error, but there was one: %v", i, fatalErr) - } - if fatalErr != nil { - continue - } - - if test.shouldWarnErr && warnErr == nil { - t.Errorf("Test %d: Expected warning, but there wasn't any", i) - } - if !test.shouldWarnErr && warnErr != nil { - t.Errorf("Test %d: Expected no warning, but there was one: %v", i, warnErr) - } - - if actual, expected := actualAddr.IP.String(), test.expectedIP; actual != expected { - t.Errorf("Test %d: IP was %s but expected %s", i, actual, expected) - } - if actual, expected := actualAddr.Port, test.expectedPort; actual != expected { - t.Errorf("Test %d: Port was %d but expected %d", i, actual, expected) - } - } -} - -func TestMakeOnces(t *testing.T) { - directives := []directive{ - {"dummy", nil}, - {"dummy2", nil}, - } - directiveOrder = directives - onces := makeOnces() - if len(onces) != len(directives) { - t.Errorf("onces had len %d , expected %d", len(onces), len(directives)) - } - expected := map[string]*sync.Once{ - "dummy": new(sync.Once), - "dummy2": new(sync.Once), - } - if !reflect.DeepEqual(onces, expected) { - t.Errorf("onces was %v, expected %v", onces, expected) - } -} - -func TestMakeStorages(t *testing.T) { - directives := []directive{ - {"dummy", nil}, - {"dummy2", nil}, - } - directiveOrder = directives - storages := makeStorages() - if len(storages) != len(directives) { - t.Errorf("storages had len %d , expected %d", len(storages), len(directives)) - } - expected := map[string]interface{}{ - "dummy": nil, - "dummy2": nil, - } - if !reflect.DeepEqual(storages, expected) { - t.Errorf("storages was %v, expected %v", storages, expected) - } -} - -func TestValidDirective(t *testing.T) { - directives := []directive{ - {"dummy", nil}, - {"dummy2", nil}, - } - directiveOrder = directives - for i, test := range []struct { - directive string - valid bool - }{ - {"dummy", true}, - {"dummy2", true}, - {"dummy3", false}, - } { - if actual, expected := validDirective(test.directive), test.valid; actual != expected { - t.Errorf("Test %d: valid was %t, expected %t", i, actual, expected) - } - } -} diff --git a/core/core.go b/core/core.go deleted file mode 100644 index 94dd06e52..000000000 --- a/core/core.go +++ /dev/null @@ -1,426 +0,0 @@ -// Package core implements the CoreDNS web server as a service -// in your own Go programs. -// -// To use this package, follow a few simple steps: -// -// 1. Set the AppName and AppVersion variables. -// 2. Call LoadCorefile() to get the Corefile (it -// might have been piped in as part of a restart). -// You should pass in your own Corefile loader. -// 3. Call core.Start() to start CoreDNS, core.Stop() -// to stop it, or core.Restart() to restart it. -// -// You should use core.Wait() to wait for all CoreDNS servers -// to quit before your process exits. -package core - -import ( - "bytes" - "encoding/gob" - "errors" - "fmt" - "io/ioutil" - "log" - "net" - "os" - "path" - "strings" - "sync" - "sync/atomic" - "testing" - "time" - - "github.com/miekg/coredns/core/https" - "github.com/miekg/coredns/server" -) - -// Configurable application parameters -var ( - // AppName is the name of the application. - AppName string - - // AppVersion is the version of the application. - AppVersion string - - // Quiet when set to true, will not show any informative output on initialization. - Quiet bool - - // PidFile is the path to the pidfile to create. - PidFile string - - // GracefulTimeout is the maximum duration of a graceful shutdown. - GracefulTimeout time.Duration -) - -var ( - // corefile is the input configuration text used for this process - corefile Input - - // corefileMu protects corefile during changes - corefileMu sync.Mutex - - // errIncompleteRestart occurs if this process is a fork - // of the parent but no Corefile was piped in - errIncompleteRestart = errors.New("incomplete restart") - - // servers is a list of all the currently-listening servers - servers []*server.Server - - // serversMu protects the servers slice during changes - serversMu sync.Mutex - - // wg is used to wait for all servers to shut down - wg sync.WaitGroup - - // loadedGob is used if this is a child process as part of - // a graceful restart; it is used to map listeners to their - // index in the list of inherited file descriptors. This - // variable is not safe for concurrent access. - loadedGob corefileGob - - // startedBefore should be set to true if CoreDNS has been started - // at least once (does not indicate whether currently running). - startedBefore bool -) - -const ( - // DefaultHost is the default host. - DefaultHost = "" - // DefaultPort is the default port. - DefaultPort = "53" - // DefaultRoot is the default root folder. - DefaultRoot = "." -) - -// Start starts CoreDNS with the given Corefile. If crfile -// is nil, the LoadCorefile function will be called to get -// one. -// -// This function blocks until all the servers are listening. -// -// Note (POSIX): If Start is called in the child process of a -// restart more than once within the duration of the graceful -// cutoff (i.e. the child process called Start a first time, -// then called Stop, then Start again within the first 5 seconds -// or however long GracefulTimeout is) and the Corefiles have -// at least one listener address in common, the second Start -// may fail with "address already in use" as there's no -// guarantee that the parent process has relinquished the -// address before the grace period ends. -func Start(crfile Input) (err error) { - // If we return with no errors, we must do two things: tell the - // parent that we succeeded and write to the pidfile. - defer func() { - if err == nil { - signalSuccessToParent() // TODO: Is doing this more than once per process a bad idea? Start could get called more than once in other apps. - if PidFile != "" { - err := writePidFile() - if err != nil { - log.Printf("[ERROR] Could not write pidfile: %v", err) - } - } - } - }() - - // Input must never be nil; try to load something - if crfile == nil { - crfile, err = LoadCorefile(nil) - if err != nil { - return err - } - } - - corefileMu.Lock() - corefile = crfile - corefileMu.Unlock() - - // load the server configs (activates Let's Encrypt) - configs, err := loadConfigs(path.Base(crfile.Path()), bytes.NewReader(crfile.Body())) - if err != nil { - return err - } - - // group zones by address - groupings, err := arrangeBindings(configs) - if err != nil { - return err - } - - // Start each server with its one or more configurations - err = startServers(groupings) - if err != nil { - return err - } - startedBefore = true - - // Show initialization output - if !Quiet && !IsRestart() { - var checkedFdLimit bool - for _, group := range groupings { - for _, conf := range group.Configs { - // Print address of site - fmt.Println(conf.Address()) - - // Note if non-localhost site resolves to loopback interface - if group.BindAddr.IP.IsLoopback() && !isLocalhost(conf.Host) { - fmt.Printf("Notice: %s is only accessible on this machine (%s)\n", - conf.Host, group.BindAddr.IP.String()) - } - if !checkedFdLimit && !group.BindAddr.IP.IsLoopback() && !isLocalhost(conf.Host) { - checkFdlimit() - checkedFdLimit = true - } - } - } - } - - return nil -} - -// startServers starts all the servers in groupings, -// taking into account whether or not this process is -// a child from a graceful restart or not. It blocks -// until the servers are listening. -func startServers(groupings bindingGroup) error { - var startupWg sync.WaitGroup - errChan := make(chan error, len(groupings)) // must be buffered to allow Serve functions below to return if stopped later - - for _, group := range groupings { - s, err := server.New(group.BindAddr.String(), group.Configs, GracefulTimeout) - if err != nil { - return err - } - // TODO(miek): does not work, because this callback uses http instead of dns - // s.ReqCallback = https.RequestCallback // ensures we can solve ACME challenges while running - if s.OnDemandTLS { - s.TLSConfig.GetCertificate = https.GetOrObtainCertificate // TLS on demand -- awesome! - } else { - s.TLSConfig.GetCertificate = https.GetCertificate - } - - var ( - ln net.Listener - pc net.PacketConn - ) - - if IsRestart() { - // Look up this server's listener in the map of inherited file descriptors; if we don't have one, we must make a new one (later). - if fdIndex, ok := loadedGob.ListenerFds["tcp"+s.Addr]; ok { - file := os.NewFile(fdIndex, "") - - fln, err := net.FileListener(file) - if err != nil { - return err - } - - ln, ok = fln.(*net.TCPListener) - if !ok { - return errors.New("listener for " + s.Addr + " was not a *net.TCPListener") - } - - file.Close() - delete(loadedGob.ListenerFds, "tcp"+s.Addr) - } - if fdIndex, ok := loadedGob.ListenerFds["udp"+s.Addr]; ok { - file := os.NewFile(fdIndex, "") - - fpc, err := net.FilePacketConn(file) - if err != nil { - return err - } - - pc, ok = fpc.(*net.UDPConn) - if !ok { - return errors.New("packetConn for " + s.Addr + " was not a *net.PacketConn") - } - - file.Close() - delete(loadedGob.ListenerFds, "udp"+s.Addr) - } - } - - wg.Add(1) - go func(s *server.Server, ln net.Listener, pc net.PacketConn) { - defer wg.Done() - - // run startup functions that should only execute when the original parent process is starting. - if !IsRestart() && !startedBefore { - err := s.RunFirstStartupFuncs() - if err != nil { - errChan <- err - return - } - } - - // start the server - if ln != nil && pc != nil { - errChan <- s.Serve(ln, pc) - } else { - errChan <- s.ListenAndServe() - } - }(s, ln, pc) - - startupWg.Add(1) - go func(s *server.Server) { - defer startupWg.Done() - s.WaitUntilStarted() - }(s) - - serversMu.Lock() - servers = append(servers, s) - serversMu.Unlock() - } - - // Close the remaining (unused) file descriptors to free up resources - if IsRestart() { - for key, fdIndex := range loadedGob.ListenerFds { - os.NewFile(fdIndex, "").Close() - delete(loadedGob.ListenerFds, key) - } - } - - // Wait for all servers to finish starting - startupWg.Wait() - - // Return the first error, if any - select { - case err := <-errChan: - // "use of closed network connection" is normal if it was a graceful shutdown - if err != nil && !strings.Contains(err.Error(), "use of closed network connection") { - return err - } - default: - } - - return nil -} - -// Stop stops all servers. It blocks until they are all stopped. -// It does NOT execute shutdown callbacks that may have been -// configured by middleware (they must be executed separately). -func Stop() error { - https.Deactivate() - - serversMu.Lock() - for _, s := range servers { - if err := s.Stop(); err != nil { - log.Printf("[ERROR] Stopping %s: %v", s.Addr, err) - } - } - servers = []*server.Server{} // don't reuse servers - serversMu.Unlock() - - return nil -} - -// Wait blocks until all servers are stopped. -func Wait() { - wg.Wait() -} - -// LoadCorefile loads a Corefile, prioritizing a Corefile -// piped from stdin as part of a restart (only happens on first call -// to LoadCorefile). If it is not a restart, this function tries -// calling the user's loader function, and if that returns nil, then -// this function resorts to the default configuration. Thus, if there -// are no other errors, this function always returns at least the -// default Corefile. -func LoadCorefile(loader func() (Input, error)) (crfile Input, err error) { - // If we are a fork, finishing the restart is highest priority; - // piped input is required in this case. - if IsRestart() { - err := gob.NewDecoder(os.Stdin).Decode(&loadedGob) - if err != nil { - return nil, err - } - crfile = loadedGob.Corefile - atomic.StoreInt32(https.OnDemandIssuedCount, loadedGob.OnDemandTLSCertsIssued) - } - - // Try user's loader - if crfile == nil && loader != nil { - crfile, err = loader() - } - - // Otherwise revert to default - if crfile == nil { - crfile = DefaultInput() - } - - return -} - -// CorefileFromPipe loads the Corefile input from f if f is -// not interactive input. f is assumed to be a pipe or stream, -// such as os.Stdin. If f is not a pipe, no error is returned -// but the Input value will be nil. An error is only returned -// if there was an error reading the pipe, even if the length -// of what was read is 0. -func CorefileFromPipe(f *os.File) (Input, error) { - fi, err := f.Stat() - if err == nil && fi.Mode()&os.ModeCharDevice == 0 { - // Note that a non-nil error is not a problem. Windows - // will not create a stdin if there is no pipe, which - // produces an error when calling Stat(). But Unix will - // make one either way, which is why we also check that - // bitmask. - // BUG: Reading from stdin after this fails (e.g. for the let's encrypt email address) (OS X) - confBody, err := ioutil.ReadAll(f) - if err != nil { - return nil, err - } - return CorefileInput{ - Contents: confBody, - Filepath: f.Name(), - }, nil - } - - // not having input from the pipe is not itself an error, - // just means no input to return. - return nil, nil -} - -// Corefile returns the current Corefile -func Corefile() Input { - corefileMu.Lock() - defer corefileMu.Unlock() - return corefile -} - -// Input represents a Corefile; its contents and file path -// (which should include the file name at the end of the path). -// If path does not apply (e.g. piped input) you may use -// any understandable value. The path is mainly used for logging, -// error messages, and debugging. -type Input interface { - // Gets the Corefile contents - Body() []byte - - // Gets the path to the origin file - Path() string - - // IsFile returns true if the original input was a file on the file system - // that could be loaded again later if requested. - IsFile() bool -} - -// TestServer returns a test server. -// The ports can be retreived with server.LocalAddr(). The testserver itself can be stopped -// with Stop(). It just takes a normal Corefile as input. -func TestServer(t *testing.T, corefile string) (*server.Server, error) { - - crfile := CorefileInput{Contents: []byte(corefile)} - configs, err := loadConfigs(path.Base(crfile.Path()), bytes.NewReader(crfile.Body())) - if err != nil { - return nil, err - } - groupings, err := arrangeBindings(configs) - if err != nil { - return nil, err - } - t.Logf("Starting %d servers", len(groupings)) - - group := groupings[0] - s, err := server.New(group.BindAddr.String(), group.Configs, time.Second) - return s, err -} diff --git a/core/coredns.go b/core/coredns.go new file mode 100644 index 000000000..e770a1a30 --- /dev/null +++ b/core/coredns.go @@ -0,0 +1,26 @@ +package core + +import ( + // plug in the server + _ "github.com/miekg/coredns/core/dnsserver" + + // plug in the standard directives + _ "github.com/miekg/coredns/middleware/bind" + _ "github.com/miekg/coredns/middleware/health" + _ "github.com/miekg/coredns/middleware/pprof" + + _ "github.com/miekg/coredns/middleware/errors" + _ "github.com/miekg/coredns/middleware/loadbalance" + _ "github.com/miekg/coredns/middleware/log" + _ "github.com/miekg/coredns/middleware/metrics" + _ "github.com/miekg/coredns/middleware/rewrite" + + _ "github.com/miekg/coredns/middleware/cache" + _ "github.com/miekg/coredns/middleware/chaos" + _ "github.com/miekg/coredns/middleware/dnssec" + _ "github.com/miekg/coredns/middleware/etcd" + _ "github.com/miekg/coredns/middleware/file" + _ "github.com/miekg/coredns/middleware/kubernetes" + _ "github.com/miekg/coredns/middleware/proxy" + _ "github.com/miekg/coredns/middleware/secondary" +) diff --git a/core/coremain/run.go b/core/coremain/run.go deleted file mode 100644 index 407beb1c6..000000000 --- a/core/coremain/run.go +++ /dev/null @@ -1,232 +0,0 @@ -package coremain - -import ( - "errors" - "flag" - "fmt" - "io/ioutil" - "log" - "os" - "runtime" - "strconv" - "strings" - "time" - - "github.com/miekg/coredns/core" - "github.com/miekg/coredns/core/https" - "github.com/xenolf/lego/acme" - "gopkg.in/natefinch/lumberjack.v2" -) - -func init() { - core.TrapSignals() - setVersion() - flag.BoolVar(&https.Agreed, "agree", false, "Agree to Let's Encrypt Subscriber Agreement") - flag.StringVar(&https.CAUrl, "ca", "https://acme-v01.api.letsencrypt.org/directory", "Certificate authority ACME server") - flag.StringVar(&conf, "conf", "", "Configuration file to use (default="+core.DefaultConfigFile+")") - flag.StringVar(&cpu, "cpu", "100%", "CPU cap") - flag.StringVar(&https.DefaultEmail, "email", "", "Default Let's Encrypt account email address") - flag.DurationVar(&core.GracefulTimeout, "grace", 5*time.Second, "Maximum duration of graceful shutdown") - flag.StringVar(&core.Host, "host", core.DefaultHost, "Default host") - flag.StringVar(&logfile, "log", "", "Process log file") - flag.StringVar(&core.PidFile, "pidfile", "", "Path to write pid file") - flag.StringVar(&core.Port, "port", core.DefaultPort, "Default port") - flag.BoolVar(&core.Quiet, "quiet", false, "Quiet mode (no initialization output)") - flag.StringVar(&revoke, "revoke", "", "Hostname for which to revoke the certificate") - flag.StringVar(&core.Root, "root", core.DefaultRoot, "Root path to default zone files") - flag.BoolVar(&version, "version", false, "Show version") -} - -func Run() { - flag.Parse() // called here in Run() to allow other packages to set flags in their inits - - core.AppName = appName - core.AppVersion = appVersion - acme.UserAgent = appName + "/" + appVersion - - // set up process log before anything bad happens - switch logfile { - case "stdout": - log.SetOutput(os.Stdout) - case "stderr": - log.SetOutput(os.Stderr) - case "": - log.SetOutput(ioutil.Discard) - default: - log.SetOutput(&lumberjack.Logger{ - Filename: logfile, - MaxSize: 100, - MaxAge: 14, - MaxBackups: 10, - }) - } - - if revoke != "" { - err := https.Revoke(revoke) - if err != nil { - log.Fatal(err) - } - fmt.Printf("Revoked certificate for %s\n", revoke) - os.Exit(0) - } - if version { - fmt.Printf("%s %s\n", appName, appVersion) - if devBuild && gitShortStat != "" { - fmt.Printf("%s\n%s\n", gitShortStat, gitFilesModified) - } - os.Exit(0) - } - - // Set CPU cap - err := setCPU(cpu) - if err != nil { - mustLogFatal(err) - } - - // Get Corefile input - corefile, err := core.LoadCorefile(loadCorefile) - if err != nil { - mustLogFatal(err) - } - - // Start your engines - err = core.Start(corefile) - if err != nil { - mustLogFatal(err) - } - - // Twiddle your thumbs - core.Wait() -} - -// mustLogFatal just wraps log.Fatal() in a way that ensures the -// output is always printed to stderr so the user can see it -// if the user is still there, even if the process log was not -// enabled. If this process is a restart, however, and the user -// might not be there anymore, this just logs to the process log -// and exits. -func mustLogFatal(args ...interface{}) { - if !core.IsRestart() { - log.SetOutput(os.Stderr) - } - log.Fatal(args...) -} - -func loadCorefile() (core.Input, error) { - // Try -conf flag - if conf != "" { - if conf == "stdin" { - return core.CorefileFromPipe(os.Stdin) - } - - contents, err := ioutil.ReadFile(conf) - if err != nil { - return nil, err - } - - return core.CorefileInput{ - Contents: contents, - Filepath: conf, - RealFile: true, - }, nil - } - - // command line args - if flag.NArg() > 0 { - confBody := core.Host + ":" + core.Port + "\n" + strings.Join(flag.Args(), "\n") - return core.CorefileInput{ - Contents: []byte(confBody), - Filepath: "args", - }, nil - } - - // Corefile in cwd - contents, err := ioutil.ReadFile(core.DefaultConfigFile) - if err != nil { - if os.IsNotExist(err) { - return core.DefaultInput(), nil - } - return nil, err - } - return core.CorefileInput{ - Contents: contents, - Filepath: core.DefaultConfigFile, - RealFile: true, - }, nil -} - -// setCPU parses string cpu and sets GOMAXPROCS -// according to its value. It accepts either -// a number (e.g. 3) or a percent (e.g. 50%). -func setCPU(cpu string) error { - var numCPU int - - availCPU := runtime.NumCPU() - - if strings.HasSuffix(cpu, "%") { - // Percent - var percent float32 - pctStr := cpu[:len(cpu)-1] - pctInt, err := strconv.Atoi(pctStr) - if err != nil || pctInt < 1 || pctInt > 100 { - return errors.New("invalid CPU value: percentage must be between 1-100") - } - percent = float32(pctInt) / 100 - numCPU = int(float32(availCPU) * percent) - } else { - // Number - num, err := strconv.Atoi(cpu) - if err != nil || num < 1 { - return errors.New("invalid CPU value: provide a number or percent greater than 0") - } - numCPU = num - } - - if numCPU > availCPU { - numCPU = availCPU - } - - runtime.GOMAXPROCS(numCPU) - return nil -} - -// setVersion figures out the version information based on -// variables set by -ldflags. -func setVersion() { - // A development build is one that's not at a tag or has uncommitted changes - devBuild = gitTag == "" || gitShortStat != "" - - // Only set the appVersion if -ldflags was used - if gitNearestTag != "" || gitTag != "" { - if devBuild && gitNearestTag != "" { - appVersion = fmt.Sprintf("%s (+%s %s)", - strings.TrimPrefix(gitNearestTag, "v"), gitCommit, buildDate) - } else if gitTag != "" { - appVersion = strings.TrimPrefix(gitTag, "v") - } - } -} - -const appName = "CoreDNS" - -// Flags that control program flow or startup -var ( - conf string - cpu string - logfile string - revoke string - version bool -) - -// Build information obtained with the help of -ldflags -var ( - appVersion = "(untracked dev build)" // inferred at startup - devBuild = true // inferred at startup - - buildDate string // date -u - gitTag string // git describe --exact-match HEAD 2> /dev/null - gitNearestTag string // git describe --abbrev=0 --tags HEAD - gitCommit string // git rev-parse HEAD - gitShortStat string // git diff-index --shortstat - gitFilesModified string // git diff-index --name-only HEAD -) diff --git a/core/coremain/run_test.go b/core/coremain/run_test.go deleted file mode 100644 index 149dab0c1..000000000 --- a/core/coremain/run_test.go +++ /dev/null @@ -1,75 +0,0 @@ -package coremain - -import ( - "runtime" - "testing" -) - -func TestSetCPU(t *testing.T) { - currentCPU := runtime.GOMAXPROCS(-1) - maxCPU := runtime.NumCPU() - halfCPU := int(0.5 * float32(maxCPU)) - if halfCPU < 1 { - halfCPU = 1 - } - for i, test := range []struct { - input string - output int - shouldErr bool - }{ - {"1", 1, false}, - {"-1", currentCPU, true}, - {"0", currentCPU, true}, - {"100%", maxCPU, false}, - {"50%", halfCPU, false}, - {"110%", currentCPU, true}, - {"-10%", currentCPU, true}, - {"invalid input", currentCPU, true}, - {"invalid input%", currentCPU, true}, - {"9999", maxCPU, false}, // over available CPU - } { - err := setCPU(test.input) - if test.shouldErr && err == nil { - t.Errorf("Test %d: Expected error, but there wasn't any", i) - } - if !test.shouldErr && err != nil { - t.Errorf("Test %d: Expected no error, but there was one: %v", i, err) - } - if actual, expected := runtime.GOMAXPROCS(-1), test.output; actual != expected { - t.Errorf("Test %d: GOMAXPROCS was %d but expected %d", i, actual, expected) - } - // teardown - runtime.GOMAXPROCS(currentCPU) - } -} - -func TestSetVersion(t *testing.T) { - setVersion() - if !devBuild { - t.Error("Expected default to assume development build, but it didn't") - } - if got, want := appVersion, "(untracked dev build)"; got != want { - t.Errorf("Expected appVersion='%s', got: '%s'", want, got) - } - - gitTag = "v1.1" - setVersion() - if devBuild { - t.Error("Expected a stable build if gitTag is set with no changes") - } - if got, want := appVersion, "1.1"; got != want { - t.Errorf("Expected appVersion='%s', got: '%s'", want, got) - } - - gitTag = "" - gitNearestTag = "v1.0" - gitCommit = "deadbeef" - buildDate = "Fri Feb 26 06:53:17 UTC 2016" - setVersion() - if !devBuild { - t.Error("Expected inferring a dev build when gitTag is empty") - } - if got, want := appVersion, "1.0 (+deadbeef Fri Feb 26 06:53:17 UTC 2016)"; got != want { - t.Errorf("Expected appVersion='%s', got: '%s'", want, got) - } -} diff --git a/core/directives.go b/core/directives.go deleted file mode 100644 index 63e245578..000000000 --- a/core/directives.go +++ /dev/null @@ -1,98 +0,0 @@ -package core - -import ( - "github.com/miekg/coredns/core/https" - "github.com/miekg/coredns/core/parse" - "github.com/miekg/coredns/core/setup" - "github.com/miekg/coredns/middleware" -) - -func init() { - // The parse package must know which directives - // are valid, but it must not import the setup - // or config package. To solve this problem, we - // fill up this map in our init function here. - // The parse package does not need to know the - // ordering of the directives. - for _, dir := range directiveOrder { - parse.ValidDirectives[dir.name] = struct{}{} - } -} - -// Directives are registered in the order they should be -// executed. Middleware (directives that inject a handler) -// are executed in the order A-B-C-*-C-B-A, assuming -// they all call the Next handler in the chain. -// -// Ordering is VERY important. Every middleware will -// feel the effects of all other middleware below -// (after) them during a request, but they must not -// care what middleware above them are doing. -// -// For example, log needs to know the status code and -// exactly how many bytes were written to the client, -// which every other middleware can affect, so it gets -// registered first. The errors middleware does not -// care if gzip or log modifies its response, so it -// gets registered below them. Gzip, on the other hand, -// DOES care what errors does to the response since it -// must compress every output to the client, even error -// pages, so it must be registered before the errors -// middleware and any others that would write to the -// response. -var directiveOrder = []directive{ - // Essential directives that initialize vital configuration settings - {"root", setup.Root}, - {"bind", setup.BindHost}, - {"tls", https.Setup}, - {"health", setup.Health}, - {"pprof", setup.PProf}, - - // Other directives that don't create HTTP handlers - {"startup", setup.Startup}, - {"shutdown", setup.Shutdown}, - - // Directives that inject handlers (middleware) - {"prometheus", setup.Prometheus}, - {"errors", setup.Errors}, - {"log", setup.Log}, - - {"chaos", setup.Chaos}, - {"rewrite", setup.Rewrite}, - {"loadbalance", setup.Loadbalance}, - {"cache", setup.Cache}, - {"dnssec", setup.Dnssec}, - {"file", setup.File}, - {"secondary", setup.Secondary}, - {"etcd", setup.Etcd}, - {"kubernetes", setup.Kubernetes}, - {"proxy", setup.Proxy}, -} - -// RegisterDirective adds the given directive to CoreDNS's list of directives. -// Pass the name of a directive you want it to be placed after, -// otherwise it will be placed at the bottom of the stack. -func RegisterDirective(name string, setup SetupFunc, after string) { - dir := directive{name: name, setup: setup} - idx := len(directiveOrder) - for i := range directiveOrder { - if directiveOrder[i].name == after { - idx = i + 1 - break - } - } - newDirectives := append(directiveOrder[:idx], append([]directive{dir}, directiveOrder[idx:]...)...) - directiveOrder = newDirectives - parse.ValidDirectives[name] = struct{}{} -} - -// directive ties together a directive name with its setup function. -type directive struct { - name string - setup SetupFunc -} - -// SetupFunc takes a controller and may optionally return a middleware. -// If the resulting middleware is not nil, it will be chained into -// the DNS handlers in the order specified in this package. -type SetupFunc func(c *setup.Controller) (middleware.Middleware, error) diff --git a/core/directives_test.go b/core/directives_test.go deleted file mode 100644 index 1bee144f5..000000000 --- a/core/directives_test.go +++ /dev/null @@ -1,31 +0,0 @@ -package core - -import ( - "reflect" - "testing" -) - -func TestRegister(t *testing.T) { - directives := []directive{ - {"dummy", nil}, - {"dummy2", nil}, - } - directiveOrder = directives - RegisterDirective("foo", nil, "dummy") - if len(directiveOrder) != 3 { - t.Fatal("Should have 3 directives now") - } - getNames := func() (s []string) { - for _, d := range directiveOrder { - s = append(s, d.name) - } - return s - } - if !reflect.DeepEqual(getNames(), []string{"dummy", "foo", "dummy2"}) { - t.Fatalf("directive order doesn't match: %s", getNames()) - } - RegisterDirective("bar", nil, "ASDASD") - if !reflect.DeepEqual(getNames(), []string{"dummy", "foo", "dummy2", "bar"}) { - t.Fatalf("directive order doesn't match: %s", getNames()) - } -} diff --git a/core/dns/storage.go b/core/dns/storage.go deleted file mode 100644 index 0c8e68437..000000000 --- a/core/dns/storage.go +++ /dev/null @@ -1,28 +0,0 @@ -package dns - -import ( - "path/filepath" - - "github.com/miekg/coredns/core/assets" -) - -var storage = Storage(assets.Path()) - -// Storage is a root directory and facilitates -// forming file paths derived from it. -type Storage string - -// Zones gets the directory that stores zones data. -func (s Storage) Zones() string { - return filepath.Join(string(s), "zones") -} - -// Zone returns the path to the folder containing assets for domain. -func (s Storage) Zone(domain string) string { - return filepath.Join(s.Zones(), domain) -} - -// SecondaryZoneFile returns the path to domain's secondary zone file (when fetched). -func (s Storage) SecondaryZoneFile(domain string) string { - return filepath.Join(s.Zone(domain), "db."+domain) -} diff --git a/core/dns/storage_test.go b/core/dns/storage_test.go deleted file mode 100644 index 859435a15..000000000 --- a/core/dns/storage_test.go +++ /dev/null @@ -1,20 +0,0 @@ -package dns - -import ( - "path/filepath" - "testing" -) - -func TestStorage(t *testing.T) { - storage = Storage("./le_test") - - if expected, actual := filepath.Join("le_test", "zones"), storage.Zones(); actual != expected { - t.Errorf("Expected Zones() to return '%s' but got '%s'", expected, actual) - } - if expected, actual := filepath.Join("le_test", "zones", "test.com"), storage.Zone("test.com"); actual != expected { - t.Errorf("Expected Site() to return '%s' but got '%s'", expected, actual) - } - if expected, actual := filepath.Join("le_test", "zones", "test.com", "db.test.com"), storage.SecondaryZoneFile("test.com"); actual != expected { - t.Errorf("Expected SecondaryZoneFile() to return '%s' but got '%s'", expected, actual) - } -} diff --git a/core/dnsserver/address.go b/core/dnsserver/address.go new file mode 100644 index 000000000..865d082cc --- /dev/null +++ b/core/dnsserver/address.go @@ -0,0 +1,44 @@ +package dnsserver + +import ( + "fmt" + "net" + "strings" + + "github.com/miekg/dns" +) + +type zoneAddr struct { + Zone string + Port string +} + +// String return z.Zone + ":" + z.Port as a string. +func (z zoneAddr) String() string { return z.Zone + ":" + z.Port } + +// normalizeZone parses an zone string into a structured format with separate +// host, and port portions, as well as the original input string. +func normalizeZone(str string) (zoneAddr, error) { + var err error + + // separate host and port + host, port, err := net.SplitHostPort(str) + if err != nil { + host, port, err = net.SplitHostPort(str + ":") + // no error check here; return err at end of function + } + + if len(host) > 255 { + return zoneAddr{}, fmt.Errorf("specified zone is too long: %d > 255", len(host)) + } + _, d := dns.IsDomainName(host) + if !d { + return zoneAddr{}, fmt.Errorf("zone is not a valid domain name: %s", host) + } + + if port == "" { + port = "53" + } + + return zoneAddr{Zone: strings.ToLower(dns.Fqdn(host)), Port: port}, err +} diff --git a/core/dnsserver/config.go b/core/dnsserver/config.go new file mode 100644 index 000000000..7af483f21 --- /dev/null +++ b/core/dnsserver/config.go @@ -0,0 +1,38 @@ +package dnsserver + +import "github.com/mholt/caddy" + +// Config configuration for a single server. +type Config struct { + // The zone of the site. + Zone string + + // The hostname to bind listener to, defaults to the wildcard address + ListenHost string + + // The port to listen on. + Port string + + // The directory from which to parse db files, and store keys. + Root string + + // Middleware stack. + Middleware []Middleware + + // Compiled middleware stack. + middlewareChain Handler +} + +// GetConfig gets the Config that corresponds to c. +// If none exist nil is returned. +func GetConfig(c *caddy.Controller) *Config { + ctx := c.Context().(*dnsContext) + if cfg, ok := ctx.keysToConfigs[c.Key]; ok { + return cfg + } + // we should only get here during tests because directive + // actions typically skip the server blocks where we make + // the configs. + ctx.saveConfig(c.Key, &Config{Root: Root}) + return GetConfig(c) +} diff --git a/core/dnsserver/directives.go b/core/dnsserver/directives.go new file mode 100644 index 000000000..78a8a11f7 --- /dev/null +++ b/core/dnsserver/directives.go @@ -0,0 +1,32 @@ +package dnsserver + +// Add here, and in core/coredns.go to use them. + +// Directives are registered in the order they should be +// executed. +// +// Ordering is VERY important. Every middleware will +// feel the effects of all other middleware below +// (after) them during a request, but they must not +// care what middleware above them are doing. +var Directives = []string{ + "bind", + "health", + "pprof", + + "prometheus", + "errors", + "log", + "chaos", + "cache", + + "rewrite", + "loadbalance", + + "dnssec", + "file", + "secondary", + "etcd", + "kubernetes", + "proxy", +} diff --git a/core/dnsserver/middleware.go b/core/dnsserver/middleware.go new file mode 100644 index 000000000..5bce304b1 --- /dev/null +++ b/core/dnsserver/middleware.go @@ -0,0 +1,52 @@ +package dnsserver + +import ( + "github.com/miekg/dns" + "golang.org/x/net/context" +) + +type ( + // Middleware is the middle layer which represents the traditional + // idea of middleware: it chains one Handler to the next by being + // passed the next Handler in the chain. + Middleware func(Handler) Handler + + // Handler is like dns.Handler except ServeDNS may return an rcode + // and/or error. + // + // If ServeDNS writes to the response body, it should return a status + // code. If the status code is not one of the following: + // * SERVFAIL (dns.RcodeServerFailure) + // * REFUSED (dns.RecodeRefused) + // * FORMERR (dns.RcodeFormatError) + // * NOTIMP (dns.RcodeNotImplemented) + // + // CoreDNS assumes *no* reply has yet been written. All other response + // codes signal other handlers above it that the response message is + // already written, and that they should not write to it also. + // + // If ServeDNS encounters an error, it should return the error value + // so it can be logged by designated error-handling middleware. + // + // If writing a response after calling another ServeDNS method, the + // returned rcode SHOULD be used when writing the response. + // + // If handling errors after calling another ServeDNS method, the + // returned error value SHOULD be logged or handled accordingly. + // + // Otherwise, return values should be propagated down the middleware + // chain by returning them unchanged. + Handler interface { + ServeDNS(context.Context, dns.ResponseWriter, *dns.Msg) (int, error) + } + + // HandlerFunc is a convenience type like dns.HandlerFunc, except + // ServeDNS returns an rcode and an error. See Handler + // documentation for more information. + HandlerFunc func(context.Context, dns.ResponseWriter, *dns.Msg) (int, error) +) + +// ServeDNS implements the Handler interface. +func (f HandlerFunc) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + return f(ctx, w, r) +} diff --git a/core/dnsserver/register.go b/core/dnsserver/register.go new file mode 100644 index 000000000..3c12c019c --- /dev/null +++ b/core/dnsserver/register.go @@ -0,0 +1,156 @@ +package dnsserver + +import ( + "fmt" + "net" + "time" + + "github.com/mholt/caddy" + "github.com/mholt/caddy/caddyfile" +) + +const serverType = "dns" + +func init() { + caddy.RegisterServerType(serverType, caddy.ServerType{ + Directives: Directives, + DefaultInput: func() caddy.Input { + if Port == DefaultPort && Zone != "" { + return caddy.CaddyfileInput{ + Filepath: "Corefile", + Contents: nil, + ServerTypeName: serverType, + } + } + return caddy.CaddyfileInput{ + Filepath: "Corefile", + Contents: nil, + ServerTypeName: serverType, + } + }, + NewContext: newContext, + }) +} + +var TestNewContext = newContext + +func newContext() caddy.Context { + return &dnsContext{keysToConfigs: make(map[string]*Config)} +} + +type dnsContext struct { + keysToConfigs map[string]*Config + + // configs is the master list of all site configs. + configs []*Config +} + +func (h *dnsContext) saveConfig(key string, cfg *Config) { + h.configs = append(h.configs, cfg) + h.keysToConfigs[key] = cfg +} + +// InspectServerBlocks make sure that everything checks out before +// executing directives and otherwise prepares the directives to +// be parsed and executed. +func (h *dnsContext) InspectServerBlocks(sourceFile string, serverBlocks []caddyfile.ServerBlock) ([]caddyfile.ServerBlock, error) { + // Normalize and check all the zone names and check for duplicates + dups := map[string]string{} + for _, s := range serverBlocks { + for i, k := range s.Keys { + za, err := normalizeZone(k) + if err != nil { + return nil, err + } + s.Keys[i] = za.String() + if v, ok := dups[za.Zone]; ok { + return nil, fmt.Errorf("cannot serve %s - zone already defined for %v", za, v) + + } + dups[za.Zone] = za.String() + + // Save the config to our master list, and key it for lookups + cfg := &Config{ + Zone: za.Zone, + Port: za.Port, + // TODO(miek): more? + } + h.saveConfig(za.String(), cfg) + } + } + return serverBlocks, nil +} + +// MakeServers uses the newly-created siteConfigs to create and return a list of server instances. +func (h *dnsContext) MakeServers() ([]caddy.Server, error) { + + // we must map (group) each config to a bind address + groups, err := groupConfigsByListenAddr(h.configs) + if err != nil { + return nil, err + } + // then we create a server for each group + var servers []caddy.Server + for addr, group := range groups { + s, err := NewServer(addr, group) + if err != nil { + return nil, err + } + servers = append(servers, s) + } + + return servers, nil +} + +// AddMiddleware adds a middleware to a site's middleware stack. +func (sc *Config) AddMiddleware(m Middleware) { + sc.Middleware = append(sc.Middleware, m) +} + +// groupSiteConfigsByListenAddr groups site configs by their listen +// (bind) address, so sites that use the same listener can be served +// on the same server instance. The return value maps the listen +// address (what you pass into net.Listen) to the list of site configs. +// This function does NOT vet the configs to ensure they are compatible. +func groupConfigsByListenAddr(configs []*Config) (map[string][]*Config, error) { + groups := make(map[string][]*Config) + + for _, conf := range configs { + if conf.Port == "" { + conf.Port = Port + } + addr, err := net.ResolveTCPAddr("tcp", net.JoinHostPort(conf.ListenHost, conf.Port)) + if err != nil { + return nil, err + } + addrstr := addr.String() + groups[addrstr] = append(groups[addrstr], conf) + } + + return groups, nil +} + +const ( + // DefaultZone is the default zone. + DefaultZone = "." + // DefaultPort is the default port. + DefaultPort = "2053" + // DefaultRoot is the default root folder. + DefaultRoot = "." +) + +// These "soft defaults" are configurable by +// command line flags, etc. +var ( + // Root is the site root + Root = DefaultRoot + + // Host is the site host + Zone = DefaultZone + + // Port is the site port + Port = DefaultPort + + // GracefulTimeout is the maximum duration of a graceful shutdown. + GracefulTimeout time.Duration +) diff --git a/core/dnsserver/server.go b/core/dnsserver/server.go new file mode 100644 index 000000000..27c62312b --- /dev/null +++ b/core/dnsserver/server.go @@ -0,0 +1,254 @@ +package dnsserver + +import ( + "log" + "net" + "runtime" + "sync" + "time" + + "github.com/miekg/coredns/middleware" + + "github.com/miekg/dns" + "golang.org/x/net/context" +) + +// Server represents an instance of a server, which serves +// DNS requests at a particular address (host and port). A +// server is capable of serving numerous zones on +// the same address and the listener may be stopped for +// graceful termination (POSIX only). +type Server struct { + Addr string // Address we listen on + mux *dns.ServeMux + server [2]*dns.Server // 0 is a net.Listener, 1 is a net.PacketConn (a *UDPConn) in our case. + + l net.Listener + p net.PacketConn + m sync.Mutex // protects listener and packetconn + + zones map[string]*Config // zones keyed by their address + dnsWg sync.WaitGroup // used to wait on outstanding connections + connTimeout time.Duration // the maximum duration of a graceful shutdown +} + +func NewServer(addr string, group []*Config) (*Server, error) { + + s := &Server{ + Addr: addr, + zones: make(map[string]*Config), + connTimeout: 5 * time.Second, // TODO(miek): was configurable + } + mux := dns.NewServeMux() + mux.Handle(".", s) // wildcard handler, everything will go through here + s.mux = mux + + // We have to bound our wg with one increment + // to prevent a "race condition" that is hard-coded + // into sync.WaitGroup.Wait() - basically, an add + // with a positive delta must be guaranteed to + // occur before Wait() is called on the wg. + // In a way, this kind of acts as a safety barrier. + s.dnsWg.Add(1) + + for _, site := range group { + // set the config per zone + s.zones[site.Zone] = site + // compile custom middleware for everything + var stack Handler + for i := len(site.Middleware) - 1; i >= 0; i-- { + stack = site.Middleware[i](stack) + } + site.middlewareChain = stack + } + + return s, nil +} + +// LocalAddr return the addresses where the server is bound to. +func (s *Server) LocalAddr() net.Addr { + s.m.Lock() + defer s.m.Unlock() + return s.l.Addr() +} + +// LocalAddrPacket return the net.PacketConn address where the server is bound to. +func (s *Server) LocalAddrPacket() net.Addr { + s.m.Lock() + defer s.m.Unlock() + return s.p.LocalAddr() +} + +// Serve starts the server with an existing listener. It blocks until the server stops. +func (s *Server) Serve(l net.Listener) error { + s.m.Lock() + s.server[tcp] = &dns.Server{Listener: l, Net: "tcp", Handler: s.mux} + s.m.Unlock() + + return s.server[tcp].ActivateAndServe() +} + +// ServePacket starts the server with an existing packetconn. It blocks until the server stops. +func (s *Server) ServePacket(p net.PacketConn) error { + s.m.Lock() + s.server[udp] = &dns.Server{PacketConn: p, Net: "udp", Handler: s.mux} + s.m.Unlock() + + return s.server[udp].ActivateAndServe() +} + +func (s *Server) Listen() (net.Listener, error) { + l, err := net.Listen("tcp", s.Addr) + if err != nil { + return nil, err + } + s.m.Lock() + s.l = l + s.m.Unlock() + return l, nil +} + +func (s *Server) ListenPacket() (net.PacketConn, error) { + p, err := net.ListenPacket("udp", s.Addr) + if err != nil { + return nil, err + } + + s.m.Lock() + s.p = p + s.m.Unlock() + return p, nil +} + +// Stop stops the server. It blocks until the server is +// totally stopped. On POSIX systems, it will wait for +// connections to close (up to a max timeout of a few +// seconds); on Windows it will close the listener +// immediately. +func (s *Server) Stop() (err error) { + + if runtime.GOOS != "windows" { + // force connections to close after timeout + done := make(chan struct{}) + go func() { + s.dnsWg.Done() // decrement our initial increment used as a barrier + s.dnsWg.Wait() + close(done) + }() + + // Wait for remaining connections to finish or + // force them all to close after timeout + select { + case <-time.After(s.connTimeout): + case <-done: + } + } + + // Close the listener now; this stops the server without delay + s.m.Lock() + if s.l != nil { + err = s.l.Close() + } + if s.p != nil { + err = s.p.Close() + } + + for _, s1 := range s.server { + err = s1.Shutdown() + } + s.m.Unlock() + return +} + +// ServeDNS is the entry point for every request to the address that s +// is bound to. It acts as a multiplexer for the requests zonename as +// defined in the request so that the correct zone +// (configuration and middleware stack) will handle the request. +func (s *Server) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + // TODO(miek): expensive to use defer + defer func() { + // In case the user doesn't enable error middleware, we still + // need to make sure that we stay alive up here + if rec := recover(); rec != nil { + DefaultErrorFunc(w, r, dns.RcodeServerFailure) + } + }() + + if m, err := middleware.Edns0Version(r); err != nil { // Wrong EDNS version, return at once. + w.WriteMsg(m) + return + } + + q := r.Question[0].Name + b := make([]byte, len(q)) + off, end := 0, false + ctx := context.Background() + + for { + l := len(q[off:]) + for i := 0; i < l; i++ { + b[i] = q[off+i] + // normalize the name for the lookup + if b[i] >= 'A' && b[i] <= 'Z' { + b[i] |= ('a' - 'A') + } + } + + if h, ok := s.zones[string(b[:l])]; ok { + if r.Question[0].Qtype != dns.TypeDS { + rcode, _ := h.middlewareChain.ServeDNS(ctx, w, r) + if RcodeNoClientWrite(rcode) { + DefaultErrorFunc(w, r, rcode) + } + return + } + } + off, end = dns.NextLabel(q, off) + if end { + break + } + } + // Wildcard match, if we have found nothing try the root zone as a last resort. + if h, ok := s.zones["."]; ok { + rcode, _ := h.middlewareChain.ServeDNS(ctx, w, r) + if RcodeNoClientWrite(rcode) { + DefaultErrorFunc(w, r, rcode) + } + return + } + + // Still here? Error out with REFUSED and some logging + remoteHost := w.RemoteAddr().String() + DefaultErrorFunc(w, r, dns.RcodeRefused) + log.Printf("[INFO] \"%s %s %s\" - No such zone at %s (Remote: %s)", dns.Type(r.Question[0].Qtype), dns.Class(r.Question[0].Qclass), q, s.Addr, remoteHost) +} + +// DefaultErrorFunc responds to an DNS request with an error. +func DefaultErrorFunc(w dns.ResponseWriter, r *dns.Msg, rcode int) { + state := middleware.State{W: w, Req: r} + + answer := new(dns.Msg) + answer.SetRcode(r, rcode) + state.SizeAndDo(answer) + + w.WriteMsg(answer) +} + +func RcodeNoClientWrite(rcode int) bool { + switch rcode { + case dns.RcodeServerFailure: + fallthrough + case dns.RcodeRefused: + fallthrough + case dns.RcodeFormatError: + fallthrough + case dns.RcodeNotImplemented: + return true + } + return false +} + +const ( + tcp = 0 + udp = 1 +) diff --git a/core/helpers.go b/core/helpers.go deleted file mode 100644 index 82a0150af..000000000 --- a/core/helpers.go +++ /dev/null @@ -1,102 +0,0 @@ -package core - -import ( - "bytes" - "fmt" - "io/ioutil" - "log" - "os" - "os/exec" - "runtime" - "strconv" - "strings" - "sync" -) - -// isLocalhost returns true if host looks explicitly like a localhost address. -func isLocalhost(host string) bool { - return host == "localhost" || host == "::1" || strings.HasPrefix(host, "127.") -} - -// checkFdlimit issues a warning if the OS max file descriptors is below a recommended minimum. -func checkFdlimit() { - const min = 4096 - - // Warn if ulimit is too low for production sites - if runtime.GOOS == "linux" || runtime.GOOS == "darwin" { - out, err := exec.Command("sh", "-c", "ulimit -n").Output() // use sh because ulimit isn't in Linux $PATH - if err == nil { - // Note that an error here need not be reported - lim, err := strconv.Atoi(string(bytes.TrimSpace(out))) - if err == nil && lim < min { - fmt.Printf("Warning: File descriptor limit %d is too low for production sites. At least %d is recommended. Set with \"ulimit -n %d\".\n", lim, min, min) - } - } - } -} - -// signalSuccessToParent tells the parent our status using pipe at index 3. -// If this process is not a restart, this function does nothing. -// Calling this function once this process has successfully initialized -// is vital so that the parent process can unblock and kill itself. -// This function is idempotent; it executes at most once per process. -func signalSuccessToParent() { - signalParentOnce.Do(func() { - if IsRestart() { - ppipe := os.NewFile(3, "") // parent is reading from pipe at index 3 - _, err := ppipe.Write([]byte("success")) // we must send some bytes to the parent - if err != nil { - log.Printf("[ERROR] Communicating successful init to parent: %v", err) - } - ppipe.Close() - } - }) -} - -// signalParentOnce is used to make sure that the parent is only -// signaled once; doing so more than once breaks whatever socket is -// at fd 4 (the reason for this is still unclear - to reproduce, -// call Stop() and Start() in succession at least once after a -// restart, then try loading first host of Corefile in the browser). -// Do not use this directly - call signalSuccessToParent instead. -var signalParentOnce sync.Once - -// corefileGob maps bind address to index of the file descriptor -// in the Files array passed to the child process. It also contains -// the corefile contents and other state needed by the new process. -// Used only during graceful restarts where a new process is spawned. -type corefileGob struct { - ListenerFds map[string]uintptr - Corefile Input - OnDemandTLSCertsIssued int32 -} - -// IsRestart returns whether this process is, according -// to env variables, a fork as part of a graceful restart. -func IsRestart() bool { - return os.Getenv("COREDNS_RESTART") == "true" -} - -// writePidFile writes the process ID to the file at PidFile, if specified. -func writePidFile() error { - pid := []byte(strconv.Itoa(os.Getpid()) + "\n") - return ioutil.WriteFile(PidFile, pid, 0644) -} - -// CorefileInput represents a Corefile as input -// and is simply a convenient way to implement -// the Input interface. -type CorefileInput struct { - Filepath string - Contents []byte - RealFile bool -} - -// Body returns c.Contents. -func (c CorefileInput) Body() []byte { return c.Contents } - -// Path returns c.Filepath. -func (c CorefileInput) Path() string { return c.Filepath } - -// IsFile returns true if the original input was a real file on the file system. -func (c CorefileInput) IsFile() bool { return c.RealFile } diff --git a/core/https/certificates.go b/core/https/certificates.go deleted file mode 100644 index 6a8f3adc6..000000000 --- a/core/https/certificates.go +++ /dev/null @@ -1,234 +0,0 @@ -package https - -import ( - "crypto/tls" - "crypto/x509" - "errors" - "io/ioutil" - "log" - "strings" - "sync" - "time" - - "github.com/xenolf/lego/acme" - "golang.org/x/crypto/ocsp" -) - -// certCache stores certificates in memory, -// keying certificates by name. -var certCache = make(map[string]Certificate) -var certCacheMu sync.RWMutex - -// Certificate is a tls.Certificate with associated metadata tacked on. -// Even if the metadata can be obtained by parsing the certificate, -// we can be more efficient by extracting the metadata once so it's -// just there, ready to use. -type Certificate struct { - tls.Certificate - - // Names is the list of names this certificate is written for. - // The first is the CommonName (if any), the rest are SAN. - Names []string - - // NotAfter is when the certificate expires. - NotAfter time.Time - - // Managed certificates are certificates that CoreDNS is managing, - // as opposed to the user specifying a certificate and key file - // or directory and managing the certificate resources themselves. - Managed bool - - // OnDemand certificates are obtained or loaded on-demand during TLS - // handshakes (as opposed to preloaded certificates, which are loaded - // at startup). If OnDemand is true, Managed must necessarily be true. - // OnDemand certificates are maintained in the background just like - // preloaded ones, however, if an OnDemand certificate fails to renew, - // it is removed from the in-memory cache. - OnDemand bool - - // OCSP contains the certificate's parsed OCSP response. - OCSP *ocsp.Response -} - -// getCertificate gets a certificate that matches name (a server name) -// from the in-memory cache. If there is no exact match for name, it -// will be checked against names of the form '*.example.com' (wildcard -// certificates) according to RFC 6125. If a match is found, matched will -// be true. If no matches are found, matched will be false and a default -// certificate will be returned with defaulted set to true. If no default -// certificate is set, defaulted will be set to false. -// -// The logic in this function is adapted from the Go standard library, -// which is by the Go Authors. -// -// This function is safe for concurrent use. -func getCertificate(name string) (cert Certificate, matched, defaulted bool) { - var ok bool - - // Not going to trim trailing dots here since RFC 3546 says, - // "The hostname is represented ... without a trailing dot." - // Just normalize to lowercase. - name = strings.ToLower(name) - - certCacheMu.RLock() - defer certCacheMu.RUnlock() - - // exact match? great, let's use it - if cert, ok = certCache[name]; ok { - matched = true - return - } - - // try replacing labels in the name with wildcards until we get a match - labels := strings.Split(name, ".") - for i := range labels { - labels[i] = "*" - candidate := strings.Join(labels, ".") - if cert, ok = certCache[candidate]; ok { - matched = true - return - } - } - - // if nothing matches, use the default certificate or bust - cert, defaulted = certCache[""] - return -} - -// cacheManagedCertificate loads the certificate for domain into the -// cache, flagging it as Managed and, if onDemand is true, as OnDemand -// (meaning that it was obtained or loaded during a TLS handshake). -// -// This function is safe for concurrent use. -func cacheManagedCertificate(domain string, onDemand bool) (Certificate, error) { - cert, err := makeCertificateFromDisk(storage.SiteCertFile(domain), storage.SiteKeyFile(domain)) - if err != nil { - return cert, err - } - cert.Managed = true - cert.OnDemand = onDemand - cacheCertificate(cert) - return cert, nil -} - -// cacheUnmanagedCertificatePEMFile loads a certificate for host using certFile -// and keyFile, which must be in PEM format. It stores the certificate in -// memory. The Managed and OnDemand flags of the certificate will be set to -// false. -// -// This function is safe for concurrent use. -func cacheUnmanagedCertificatePEMFile(certFile, keyFile string) error { - cert, err := makeCertificateFromDisk(certFile, keyFile) - if err != nil { - return err - } - cacheCertificate(cert) - return nil -} - -// cacheUnmanagedCertificatePEMBytes makes a certificate out of the PEM bytes -// of the certificate and key, then caches it in memory. -// -// This function is safe for concurrent use. -func cacheUnmanagedCertificatePEMBytes(certBytes, keyBytes []byte) error { - cert, err := makeCertificate(certBytes, keyBytes) - if err != nil { - return err - } - cacheCertificate(cert) - return nil -} - -// makeCertificateFromDisk makes a Certificate by loading the -// certificate and key files. It fills out all the fields in -// the certificate except for the Managed and OnDemand flags. -// (It is up to the caller to set those.) -func makeCertificateFromDisk(certFile, keyFile string) (Certificate, error) { - certPEMBlock, err := ioutil.ReadFile(certFile) - if err != nil { - return Certificate{}, err - } - keyPEMBlock, err := ioutil.ReadFile(keyFile) - if err != nil { - return Certificate{}, err - } - return makeCertificate(certPEMBlock, keyPEMBlock) -} - -// makeCertificate turns a certificate PEM bundle and a key PEM block into -// a Certificate, with OCSP and other relevant metadata tagged with it, -// except for the OnDemand and Managed flags. It is up to the caller to -// set those properties. -func makeCertificate(certPEMBlock, keyPEMBlock []byte) (Certificate, error) { - var cert Certificate - - // Convert to a tls.Certificate - tlsCert, err := tls.X509KeyPair(certPEMBlock, keyPEMBlock) - if err != nil { - return cert, err - } - if len(tlsCert.Certificate) == 0 { - return cert, errors.New("certificate is empty") - } - - // Parse leaf certificate and extract relevant metadata - leaf, err := x509.ParseCertificate(tlsCert.Certificate[0]) - if err != nil { - return cert, err - } - if leaf.Subject.CommonName != "" { - cert.Names = []string{strings.ToLower(leaf.Subject.CommonName)} - } - for _, name := range leaf.DNSNames { - if name != leaf.Subject.CommonName { - cert.Names = append(cert.Names, strings.ToLower(name)) - } - } - cert.NotAfter = leaf.NotAfter - - // Staple OCSP - ocspBytes, ocspResp, err := acme.GetOCSPForCert(certPEMBlock) - if err != nil { - // An error here is not a problem because a certificate may simply - // not contain a link to an OCSP server. But we should log it anyway. - log.Printf("[WARNING] No OCSP stapling for %v: %v", cert.Names, err) - } else if ocspResp.Status == ocsp.Good { - tlsCert.OCSPStaple = ocspBytes - cert.OCSP = ocspResp - } - - cert.Certificate = tlsCert - return cert, nil -} - -// cacheCertificate adds cert to the in-memory cache. If the cache is -// empty, cert will be used as the default certificate. If the cache is -// full, random entries are deleted until there is room to map all the -// names on the certificate. -// -// This certificate will be keyed to the names in cert.Names. Any name -// that is already a key in the cache will be replaced with this cert. -// -// This function is safe for concurrent use. -func cacheCertificate(cert Certificate) { - certCacheMu.Lock() - if _, ok := certCache[""]; !ok { - // use as default - cert.Names = append(cert.Names, "") - certCache[""] = cert - } - for len(certCache)+len(cert.Names) > 10000 { - // for simplicity, just remove random elements - for key := range certCache { - if key == "" { // ... but not the default cert - continue - } - delete(certCache, key) - break - } - } - for _, name := range cert.Names { - certCache[name] = cert - } - certCacheMu.Unlock() -} diff --git a/core/https/certificates_test.go b/core/https/certificates_test.go deleted file mode 100644 index dbfb4efc1..000000000 --- a/core/https/certificates_test.go +++ /dev/null @@ -1,59 +0,0 @@ -package https - -import "testing" - -func TestUnexportedGetCertificate(t *testing.T) { - defer func() { certCache = make(map[string]Certificate) }() - - // When cache is empty - if _, matched, defaulted := getCertificate("example.com"); matched || defaulted { - t.Errorf("Got a certificate when cache was empty; matched=%v, defaulted=%v", matched, defaulted) - } - - // When cache has one certificate in it (also is default) - defaultCert := Certificate{Names: []string{"example.com", ""}} - certCache[""] = defaultCert - certCache["example.com"] = defaultCert - if cert, matched, defaulted := getCertificate("Example.com"); !matched || defaulted || cert.Names[0] != "example.com" { - t.Errorf("Didn't get a cert for 'Example.com' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted) - } - if cert, matched, defaulted := getCertificate(""); !matched || defaulted || cert.Names[0] != "example.com" { - t.Errorf("Didn't get a cert for '' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted) - } - - // When retrieving wildcard certificate - certCache["*.example.com"] = Certificate{Names: []string{"*.example.com"}} - if cert, matched, defaulted := getCertificate("sub.example.com"); !matched || defaulted || cert.Names[0] != "*.example.com" { - t.Errorf("Didn't get wildcard cert for 'sub.example.com' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted) - } - - // When no certificate matches, the default is returned - if cert, matched, defaulted := getCertificate("nomatch"); matched || !defaulted { - t.Errorf("Expected matched=false, defaulted=true; but got matched=%v, defaulted=%v (cert: %v)", matched, defaulted, cert) - } else if cert.Names[0] != "example.com" { - t.Errorf("Expected default cert, got: %v", cert) - } -} - -func TestCacheCertificate(t *testing.T) { - defer func() { certCache = make(map[string]Certificate) }() - - cacheCertificate(Certificate{Names: []string{"example.com", "sub.example.com"}}) - if _, ok := certCache["example.com"]; !ok { - t.Error("Expected first cert to be cached by key 'example.com', but it wasn't") - } - if _, ok := certCache["sub.example.com"]; !ok { - t.Error("Expected first cert to be cached by key 'sub.exmaple.com', but it wasn't") - } - if cert, ok := certCache[""]; !ok || cert.Names[2] != "" { - t.Error("Expected first cert to be cached additionally as the default certificate with empty name added, but it wasn't") - } - - cacheCertificate(Certificate{Names: []string{"example2.com"}}) - if _, ok := certCache["example2.com"]; !ok { - t.Error("Expected second cert to be cached by key 'exmaple2.com', but it wasn't") - } - if cert, ok := certCache[""]; ok && cert.Names[0] == "example2.com" { - t.Error("Expected second cert to NOT be cached as default, but it was") - } -} diff --git a/core/https/client.go b/core/https/client.go deleted file mode 100644 index e9e8cd82c..000000000 --- a/core/https/client.go +++ /dev/null @@ -1,215 +0,0 @@ -package https - -import ( - "encoding/json" - "errors" - "fmt" - "io/ioutil" - "net" - "sync" - "time" - - "github.com/miekg/coredns/server" - "github.com/xenolf/lego/acme" -) - -// acmeMu ensures that only one ACME challenge occurs at a time. -var acmeMu sync.Mutex - -// ACMEClient is an acme.Client with custom state attached. -type ACMEClient struct { - *acme.Client - AllowPrompts bool // if false, we assume AlternatePort must be used -} - -// NewACMEClient creates a new ACMEClient given an email and whether -// prompting the user is allowed. Clients should not be kept and -// re-used over long periods of time, but immediate re-use is more -// efficient than re-creating on every iteration. -var NewACMEClient = func(email string, allowPrompts bool) (*ACMEClient, error) { - // Look up or create the LE user account - leUser, err := getUser(email) - if err != nil { - return nil, err - } - - // The client facilitates our communication with the CA server. - client, err := acme.NewClient(CAUrl, &leUser, KeyType) - if err != nil { - return nil, err - } - - // If not registered, the user must register an account with the CA - // and agree to terms - if leUser.Registration == nil { - reg, err := client.Register() - if err != nil { - return nil, errors.New("registration error: " + err.Error()) - } - leUser.Registration = reg - - if allowPrompts { // can't prompt a user who isn't there - if !Agreed && reg.TosURL == "" { - Agreed = promptUserAgreement(saURL, false) // TODO - latest URL - } - if !Agreed && reg.TosURL == "" { - return nil, errors.New("user must agree to terms") - } - } - - err = client.AgreeToTOS() - if err != nil { - saveUser(leUser) // Might as well try, right? - return nil, errors.New("error agreeing to terms: " + err.Error()) - } - - // save user to the file system - err = saveUser(leUser) - if err != nil { - return nil, errors.New("could not save user: " + err.Error()) - } - } - - return &ACMEClient{ - Client: client, - AllowPrompts: allowPrompts, - }, nil -} - -// NewACMEClientGetEmail creates a new ACMEClient and gets an email -// address at the same time (a server config is required, since it -// may contain an email address in it). -func NewACMEClientGetEmail(config server.Config, allowPrompts bool) (*ACMEClient, error) { - return NewACMEClient(getEmail(config, allowPrompts), allowPrompts) -} - -// Configure configures c according to bindHost, which is the host (not -// whole address) to bind the listener to in solving the http and tls-sni -// challenges. -func (c *ACMEClient) Configure(bindHost string) { - // If we allow prompts, operator must be present. In our case, - // that is synonymous with saying the server is not already - // started. So if the user is still there, we don't use - // AlternatePort because we don't need to proxy the challenges. - // Conversely, if the operator is not there, the server has - // already started and we need to proxy the challenge. - if c.AllowPrompts { - // Operator is present; server is not already listening - c.SetHTTPAddress(net.JoinHostPort(bindHost, "")) - c.SetTLSAddress(net.JoinHostPort(bindHost, "")) - //c.ExcludeChallenges([]acme.Challenge{acme.DNS01}) - } else { - // Operator is not present; server is started, so proxy challenges - c.SetHTTPAddress(net.JoinHostPort(bindHost, AlternatePort)) - c.SetTLSAddress(net.JoinHostPort(bindHost, AlternatePort)) - //c.ExcludeChallenges([]acme.Challenge{acme.TLSSNI01, acme.DNS01}) - } - c.ExcludeChallenges([]acme.Challenge{acme.TLSSNI01, acme.DNS01}) // TODO: can we proxy TLS challenges? and we should support DNS... -} - -// Obtain obtains a single certificate for names. It stores the certificate -// on the disk if successful. -func (c *ACMEClient) Obtain(names []string) error { -Attempts: - for attempts := 0; attempts < 2; attempts++ { - acmeMu.Lock() - certificate, failures := c.ObtainCertificate(names, true, nil) - acmeMu.Unlock() - if len(failures) > 0 { - // Error - try to fix it or report it to the user and abort - var errMsg string // we'll combine all the failures into a single error message - var promptedForAgreement bool // only prompt user for agreement at most once - - for errDomain, obtainErr := range failures { - // TODO: Double-check, will obtainErr ever be nil? - if tosErr, ok := obtainErr.(acme.TOSError); ok { - // Terms of Service agreement error; we can probably deal with this - if !Agreed && !promptedForAgreement && c.AllowPrompts { - Agreed = promptUserAgreement(tosErr.Detail, true) // TODO: Use latest URL - promptedForAgreement = true - } - if Agreed || !c.AllowPrompts { - err := c.AgreeToTOS() - if err != nil { - return errors.New("error agreeing to updated terms: " + err.Error()) - } - continue Attempts - } - } - - // If user did not agree or it was any other kind of error, just append to the list of errors - errMsg += "[" + errDomain + "] failed to get certificate: " + obtainErr.Error() + "\n" - } - return errors.New(errMsg) - } - - // Success - immediately save the certificate resource - err := saveCertResource(certificate) - if err != nil { - return fmt.Errorf("error saving assets for %v: %v", names, err) - } - - break - } - - return nil -} - -// Renew renews the managed certificate for name. Right now our storage -// mechanism only supports one name per certificate, so this function only -// accepts one domain as input. It can be easily modified to support SAN -// certificates if, one day, they become desperately needed enough that our -// storage mechanism is upgraded to be more complex to support SAN certs. -// -// Anyway, this function is safe for concurrent use. -func (c *ACMEClient) Renew(name string) error { - // Prepare for renewal (load PEM cert, key, and meta) - certBytes, err := ioutil.ReadFile(storage.SiteCertFile(name)) - if err != nil { - return err - } - keyBytes, err := ioutil.ReadFile(storage.SiteKeyFile(name)) - if err != nil { - return err - } - metaBytes, err := ioutil.ReadFile(storage.SiteMetaFile(name)) - if err != nil { - return err - } - var certMeta acme.CertificateResource - err = json.Unmarshal(metaBytes, &certMeta) - certMeta.Certificate = certBytes - certMeta.PrivateKey = keyBytes - - // Perform renewal and retry if necessary, but not too many times. - var newCertMeta acme.CertificateResource - var success bool - for attempts := 0; attempts < 2; attempts++ { - acmeMu.Lock() - newCertMeta, err = c.RenewCertificate(certMeta, true) - acmeMu.Unlock() - if err == nil { - success = true - break - } - - // If the legal terms changed and need to be agreed to again, - // we can handle that. - if _, ok := err.(acme.TOSError); ok { - err := c.AgreeToTOS() - if err != nil { - return err - } - continue - } - - // For any other kind of error, wait 10s and try again. - time.Sleep(10 * time.Second) - } - - if !success { - return errors.New("too many renewal attempts; last error: " + err.Error()) - } - - return saveCertResource(newCertMeta) -} diff --git a/core/https/crypto.go b/core/https/crypto.go deleted file mode 100644 index 7971bda36..000000000 --- a/core/https/crypto.go +++ /dev/null @@ -1,57 +0,0 @@ -package https - -import ( - "crypto" - "crypto/ecdsa" - "crypto/rsa" - "crypto/x509" - "encoding/pem" - "errors" - "io/ioutil" - "os" -) - -// loadPrivateKey loads a PEM-encoded ECC/RSA private key from file. -func loadPrivateKey(file string) (crypto.PrivateKey, error) { - keyBytes, err := ioutil.ReadFile(file) - if err != nil { - return nil, err - } - keyBlock, _ := pem.Decode(keyBytes) - - switch keyBlock.Type { - case "RSA PRIVATE KEY": - return x509.ParsePKCS1PrivateKey(keyBlock.Bytes) - case "EC PRIVATE KEY": - return x509.ParseECPrivateKey(keyBlock.Bytes) - } - - return nil, errors.New("unknown private key type") -} - -// savePrivateKey saves a PEM-encoded ECC/RSA private key to file. -func savePrivateKey(key crypto.PrivateKey, file string) error { - var pemType string - var keyBytes []byte - switch key := key.(type) { - case *ecdsa.PrivateKey: - var err error - pemType = "EC" - keyBytes, err = x509.MarshalECPrivateKey(key) - if err != nil { - return err - } - case *rsa.PrivateKey: - pemType = "RSA" - keyBytes = x509.MarshalPKCS1PrivateKey(key) - } - - pemKey := pem.Block{Type: pemType + " PRIVATE KEY", Bytes: keyBytes} - keyOut, err := os.Create(file) - if err != nil { - return err - } - keyOut.Chmod(0600) - defer keyOut.Close() - return pem.Encode(keyOut, &pemKey) -} diff --git a/core/https/crypto_test.go b/core/https/crypto_test.go deleted file mode 100644 index 07d2af5c7..000000000 --- a/core/https/crypto_test.go +++ /dev/null @@ -1,111 +0,0 @@ -package https - -import ( - "bytes" - "crypto" - "crypto/ecdsa" - "crypto/elliptic" - "crypto/rand" - "crypto/rsa" - "crypto/x509" - "os" - "runtime" - "testing" -) - -func TestSaveAndLoadRSAPrivateKey(t *testing.T) { - keyFile := "test.key" - defer os.Remove(keyFile) - - privateKey, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - t.Fatal(err) - } - - // test save - err = savePrivateKey(privateKey, keyFile) - if err != nil { - t.Fatal("error saving private key:", err) - } - - // it doesn't make sense to test file permission on windows - if runtime.GOOS != "windows" { - // get info of the key file - info, err := os.Stat(keyFile) - if err != nil { - t.Fatal("error stating private key:", err) - } - // verify permission of key file is correct - if info.Mode().Perm() != 0600 { - t.Error("Expected key file to have permission 0600, but it wasn't") - } - } - - // test load - loadedKey, err := loadPrivateKey(keyFile) - if err != nil { - t.Error("error loading private key:", err) - } - - // verify loaded key is correct - if !PrivateKeysSame(privateKey, loadedKey) { - t.Error("Expected key bytes to be the same, but they weren't") - } -} - -func TestSaveAndLoadECCPrivateKey(t *testing.T) { - keyFile := "test.key" - defer os.Remove(keyFile) - - privateKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader) - if err != nil { - t.Fatal(err) - } - - // test save - err = savePrivateKey(privateKey, keyFile) - if err != nil { - t.Fatal("error saving private key:", err) - } - - // it doesn't make sense to test file permission on windows - if runtime.GOOS != "windows" { - // get info of the key file - info, err := os.Stat(keyFile) - if err != nil { - t.Fatal("error stating private key:", err) - } - // verify permission of key file is correct - if info.Mode().Perm() != 0600 { - t.Error("Expected key file to have permission 0600, but it wasn't") - } - } - - // test load - loadedKey, err := loadPrivateKey(keyFile) - if err != nil { - t.Error("error loading private key:", err) - } - - // verify loaded key is correct - if !PrivateKeysSame(privateKey, loadedKey) { - t.Error("Expected key bytes to be the same, but they weren't") - } -} - -// PrivateKeysSame compares the bytes of a and b and returns true if they are the same. -func PrivateKeysSame(a, b crypto.PrivateKey) bool { - return bytes.Equal(PrivateKeyBytes(a), PrivateKeyBytes(b)) -} - -// PrivateKeyBytes returns the bytes of DER-encoded key. -func PrivateKeyBytes(key crypto.PrivateKey) []byte { - var keyBytes []byte - switch key := key.(type) { - case *rsa.PrivateKey: - keyBytes = x509.MarshalPKCS1PrivateKey(key) - case *ecdsa.PrivateKey: - keyBytes, _ = x509.MarshalECPrivateKey(key) - } - return keyBytes -} diff --git a/core/https/handler.go b/core/https/handler.go deleted file mode 100644 index f3139f54e..000000000 --- a/core/https/handler.go +++ /dev/null @@ -1,42 +0,0 @@ -package https - -import ( - "crypto/tls" - "log" - "net/http" - "net/http/httputil" - "net/url" - "strings" -) - -const challengeBasePath = "/.well-known/acme-challenge" - -// RequestCallback proxies challenge requests to ACME client if the -// request path starts with challengeBasePath. It returns true if it -// handled the request and no more needs to be done; it returns false -// if this call was a no-op and the request still needs handling. -func RequestCallback(w http.ResponseWriter, r *http.Request) bool { - if strings.HasPrefix(r.URL.Path, challengeBasePath) { - scheme := "http" - if r.TLS != nil { - scheme = "https" - } - - upstream, err := url.Parse(scheme + "://localhost:" + AlternatePort) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - log.Printf("[ERROR] ACME proxy handler: %v", err) - return true - } - - proxy := httputil.NewSingleHostReverseProxy(upstream) - proxy.Transport = &http.Transport{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // solver uses self-signed certs - } - proxy.ServeHTTP(w, r) - - return true - } - - return false -} diff --git a/core/https/handler_test.go b/core/https/handler_test.go deleted file mode 100644 index 016799ffb..000000000 --- a/core/https/handler_test.go +++ /dev/null @@ -1,63 +0,0 @@ -package https - -import ( - "net" - "net/http" - "net/http/httptest" - "testing" -) - -func TestRequestCallbackNoOp(t *testing.T) { - // try base paths that aren't handled by this handler - for _, url := range []string{ - "http://localhost/", - "http://localhost/foo.html", - "http://localhost/.git", - "http://localhost/.well-known/", - "http://localhost/.well-known/acme-challenging", - } { - req, err := http.NewRequest("GET", url, nil) - if err != nil { - t.Fatalf("Could not craft request, got error: %v", err) - } - rw := httptest.NewRecorder() - if RequestCallback(rw, req) { - t.Errorf("Got true with this URL, but shouldn't have: %s", url) - } - } -} - -func TestRequestCallbackSuccess(t *testing.T) { - expectedPath := challengeBasePath + "/asdf" - - // Set up fake acme handler backend to make sure proxying succeeds - var proxySuccess bool - ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - proxySuccess = true - if r.URL.Path != expectedPath { - t.Errorf("Expected path '%s' but got '%s' instead", expectedPath, r.URL.Path) - } - })) - - // Custom listener that uses the port we expect - ln, err := net.Listen("tcp", "127.0.0.1:"+AlternatePort) - if err != nil { - t.Fatalf("Unable to start test server listener: %v", err) - } - ts.Listener = ln - - // Start our engines and run the test - ts.Start() - defer ts.Close() - req, err := http.NewRequest("GET", "http://127.0.0.1:"+AlternatePort+expectedPath, nil) - if err != nil { - t.Fatalf("Could not craft request, got error: %v", err) - } - rw := httptest.NewRecorder() - - RequestCallback(rw, req) - - if !proxySuccess { - t.Fatal("Expected request to be proxied, but it wasn't") - } -} diff --git a/core/https/handshake.go b/core/https/handshake.go deleted file mode 100644 index a05231c49..000000000 --- a/core/https/handshake.go +++ /dev/null @@ -1,316 +0,0 @@ -package https - -import ( - "bytes" - "crypto/tls" - "encoding/pem" - "errors" - "fmt" - "log" - "strings" - "sync" - "sync/atomic" - "time" - - "github.com/miekg/coredns/server" - "github.com/xenolf/lego/acme" -) - -// GetCertificate gets a certificate to satisfy clientHello as long as -// the certificate is already cached in memory. It will not be loaded -// from disk or obtained from the CA during the handshake. -// -// This function is safe for use as a tls.Config.GetCertificate callback. -func GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { - cert, err := getCertDuringHandshake(clientHello.ServerName, false, false) - return &cert.Certificate, err -} - -// GetOrObtainCertificate will get a certificate to satisfy clientHello, even -// if that means obtaining a new certificate from a CA during the handshake. -// It first checks the in-memory cache, then accesses disk, then accesses the -// network if it must. An obtained certificate will be stored on disk and -// cached in memory. -// -// This function is safe for use as a tls.Config.GetCertificate callback. -func GetOrObtainCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { - cert, err := getCertDuringHandshake(clientHello.ServerName, true, true) - return &cert.Certificate, err -} - -// getCertDuringHandshake will get a certificate for name. It first tries -// the in-memory cache. If no certificate for name is in the cache and if -// loadIfNecessary == true, it goes to disk to load it into the cache and -// serve it. If it's not on disk and if obtainIfNecessary == true, the -// certificate will be obtained from the CA, cached, and served. If -// obtainIfNecessary is true, then loadIfNecessary must also be set to true. -// An error will be returned if and only if no certificate is available. -// -// This function is safe for concurrent use. -func getCertDuringHandshake(name string, loadIfNecessary, obtainIfNecessary bool) (Certificate, error) { - // First check our in-memory cache to see if we've already loaded it - cert, matched, defaulted := getCertificate(name) - if matched { - return cert, nil - } - - if loadIfNecessary { - // Then check to see if we have one on disk - loadedCert, err := cacheManagedCertificate(name, true) - if err == nil { - loadedCert, err = handshakeMaintenance(name, loadedCert) - if err != nil { - log.Printf("[ERROR] Maintaining newly-loaded certificate for %s: %v", name, err) - } - return loadedCert, nil - } - - if obtainIfNecessary { - // By this point, we need to ask the CA for a certificate - - name = strings.ToLower(name) - - // Make sure aren't over any applicable limits - err := checkLimitsForObtainingNewCerts(name) - if err != nil { - return Certificate{}, err - } - - // TODO(miek): deleted, tls will be enabled when a keyword is specified. - // Obtain certificate from the CA - return obtainOnDemandCertificate(name) - } - } - - if defaulted { - return cert, nil - } - - return Certificate{}, errors.New("no certificate for " + name) -} - -// checkLimitsForObtainingNewCerts checks to see if name can be issued right -// now according to mitigating factors we keep track of and preferences the -// user has set. If a non-nil error is returned, do not issue a new certificate -// for name. -func checkLimitsForObtainingNewCerts(name string) error { - // User can set hard limit for number of certs for the process to issue - if onDemandMaxIssue > 0 && atomic.LoadInt32(OnDemandIssuedCount) >= onDemandMaxIssue { - return fmt.Errorf("%s: maximum certificates issued (%d)", name, onDemandMaxIssue) - } - - // Make sure name hasn't failed a challenge recently - failedIssuanceMu.RLock() - when, ok := failedIssuance[name] - failedIssuanceMu.RUnlock() - if ok { - return fmt.Errorf("%s: throttled; refusing to issue cert since last attempt on %s failed", name, when.String()) - } - - // Make sure, if we've issued a few certificates already, that we haven't - // issued any recently - lastIssueTimeMu.Lock() - since := time.Since(lastIssueTime) - lastIssueTimeMu.Unlock() - if atomic.LoadInt32(OnDemandIssuedCount) >= 10 && since < 10*time.Minute { - return fmt.Errorf("%s: throttled; last certificate was obtained %v ago", name, since) - } - - // 👍Good to go - return nil -} - -// obtainOnDemandCertificate obtains a certificate for name for the given -// name. If another goroutine has already started obtaining a cert for -// name, it will wait and use what the other goroutine obtained. -// -// This function is safe for use by multiple concurrent goroutines. -func obtainOnDemandCertificate(name string) (Certificate, error) { - // We must protect this process from happening concurrently, so synchronize. - obtainCertWaitChansMu.Lock() - wait, ok := obtainCertWaitChans[name] - if ok { - // lucky us -- another goroutine is already obtaining the certificate. - // wait for it to finish obtaining the cert and then we'll use it. - obtainCertWaitChansMu.Unlock() - <-wait - return getCertDuringHandshake(name, true, false) - } - - // looks like it's up to us to do all the work and obtain the cert - wait = make(chan struct{}) - obtainCertWaitChans[name] = wait - obtainCertWaitChansMu.Unlock() - - // Unblock waiters and delete waitgroup when we return - defer func() { - obtainCertWaitChansMu.Lock() - close(wait) - delete(obtainCertWaitChans, name) - obtainCertWaitChansMu.Unlock() - }() - - log.Printf("[INFO] Obtaining new certificate for %s", name) - - // obtain cert - client, err := NewACMEClientGetEmail(server.Config{}, false) - if err != nil { - return Certificate{}, errors.New("error creating client: " + err.Error()) - } - client.Configure("") // TODO: which BindHost? - err = client.Obtain([]string{name}) - if err != nil { - // Failed to solve challenge, so don't allow another on-demand - // issue for this name to be attempted for a little while. - failedIssuanceMu.Lock() - failedIssuance[name] = time.Now() - go func(name string) { - time.Sleep(5 * time.Minute) - failedIssuanceMu.Lock() - delete(failedIssuance, name) - failedIssuanceMu.Unlock() - }(name) - failedIssuanceMu.Unlock() - return Certificate{}, err - } - - // Success - update counters and stuff - atomic.AddInt32(OnDemandIssuedCount, 1) - lastIssueTimeMu.Lock() - lastIssueTime = time.Now() - lastIssueTimeMu.Unlock() - - // The certificate is already on disk; now just start over to load it and serve it - return getCertDuringHandshake(name, true, false) -} - -// handshakeMaintenance performs a check on cert for expiration and OCSP -// validity. -// -// This function is safe for use by multiple concurrent goroutines. -func handshakeMaintenance(name string, cert Certificate) (Certificate, error) { - // Check cert expiration - timeLeft := cert.NotAfter.Sub(time.Now().UTC()) - if timeLeft < renewDurationBefore { - log.Printf("[INFO] Certificate for %v expires in %v; attempting renewal", cert.Names, timeLeft) - return renewDynamicCertificate(name) - } - - // Check OCSP staple validity - if cert.OCSP != nil { - refreshTime := cert.OCSP.ThisUpdate.Add(cert.OCSP.NextUpdate.Sub(cert.OCSP.ThisUpdate) / 2) - if time.Now().After(refreshTime) { - err := stapleOCSP(&cert, nil) - if err != nil { - // An error with OCSP stapling is not the end of the world, and in fact, is - // quite common considering not all certs have issuer URLs that support it. - log.Printf("[ERROR] Getting OCSP for %s: %v", name, err) - } - certCacheMu.Lock() - certCache[name] = cert - certCacheMu.Unlock() - } - } - - return cert, nil -} - -// renewDynamicCertificate renews currentCert using the clientHello. It returns the -// certificate to use and an error, if any. currentCert may be returned even if an -// error occurs, since we perform renewals before they expire and it may still be -// usable. name should already be lower-cased before calling this function. -// -// This function is safe for use by multiple concurrent goroutines. -func renewDynamicCertificate(name string) (Certificate, error) { - obtainCertWaitChansMu.Lock() - wait, ok := obtainCertWaitChans[name] - if ok { - // lucky us -- another goroutine is already renewing the certificate. - // wait for it to finish, then we'll use the new one. - obtainCertWaitChansMu.Unlock() - <-wait - return getCertDuringHandshake(name, true, false) - } - - // looks like it's up to us to do all the work and renew the cert - wait = make(chan struct{}) - obtainCertWaitChans[name] = wait - obtainCertWaitChansMu.Unlock() - - // unblock waiters and delete waitgroup when we return - defer func() { - obtainCertWaitChansMu.Lock() - close(wait) - delete(obtainCertWaitChans, name) - obtainCertWaitChansMu.Unlock() - }() - - log.Printf("[INFO] Renewing certificate for %s", name) - - client, err := NewACMEClientGetEmail(server.Config{}, false) - if err != nil { - return Certificate{}, err - } - client.Configure("") // TODO: Bind address of relevant listener, yuck - err = client.Renew(name) - if err != nil { - return Certificate{}, err - } - - return getCertDuringHandshake(name, true, false) -} - -// stapleOCSP staples OCSP information to cert for hostname name. -// If you have it handy, you should pass in the PEM-encoded certificate -// bundle; otherwise the DER-encoded cert will have to be PEM-encoded. -// If you don't have the PEM blocks handy, just pass in nil. -// -// Errors here are not necessarily fatal, it could just be that the -// certificate doesn't have an issuer URL. -func stapleOCSP(cert *Certificate, pemBundle []byte) error { - if pemBundle == nil { - // The function in the acme package that gets OCSP requires a PEM-encoded cert - bundle := new(bytes.Buffer) - for _, derBytes := range cert.Certificate.Certificate { - pem.Encode(bundle, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) - } - pemBundle = bundle.Bytes() - } - - ocspBytes, ocspResp, err := acme.GetOCSPForCert(pemBundle) - if err != nil { - return err - } - - cert.Certificate.OCSPStaple = ocspBytes - cert.OCSP = ocspResp - - return nil -} - -// obtainCertWaitChans is used to coordinate obtaining certs for each hostname. -var obtainCertWaitChans = make(map[string]chan struct{}) -var obtainCertWaitChansMu sync.Mutex - -// OnDemandIssuedCount is the number of certificates that have been issued -// on-demand by this process. It is only safe to modify this count atomically. -// If it reaches onDemandMaxIssue, on-demand issuances will fail. -var OnDemandIssuedCount = new(int32) - -// onDemandMaxIssue is set based on max_certs in tls config. It specifies the -// maximum number of certificates that can be issued. -// TODO: This applies globally, but we should probably make a server-specific -// way to keep track of these limits and counts, since it's specified in the -// Corefile... -var onDemandMaxIssue int32 - -// failedIssuance is a set of names that we recently failed to get a -// certificate for from the ACME CA. They are removed after some time. -// When a name is in this map, do not issue a certificate for it on-demand. -var failedIssuance = make(map[string]time.Time) -var failedIssuanceMu sync.RWMutex - -// lastIssueTime records when we last obtained a certificate successfully. -// If this value is recent, do not make any on-demand certificate requests. -var lastIssueTime time.Time -var lastIssueTimeMu sync.Mutex diff --git a/core/https/handshake_test.go b/core/https/handshake_test.go deleted file mode 100644 index cf70eb17d..000000000 --- a/core/https/handshake_test.go +++ /dev/null @@ -1,54 +0,0 @@ -package https - -import ( - "crypto/tls" - "crypto/x509" - "testing" -) - -func TestGetCertificate(t *testing.T) { - defer func() { certCache = make(map[string]Certificate) }() - - hello := &tls.ClientHelloInfo{ServerName: "example.com"} - helloSub := &tls.ClientHelloInfo{ServerName: "sub.example.com"} - helloNoSNI := &tls.ClientHelloInfo{} - helloNoMatch := &tls.ClientHelloInfo{ServerName: "nomatch"} - - // When cache is empty - if cert, err := GetCertificate(hello); err == nil { - t.Errorf("GetCertificate should return error when cache is empty, got: %v", cert) - } - if cert, err := GetCertificate(helloNoSNI); err == nil { - t.Errorf("GetCertificate should return error when cache is empty even if server name is blank, got: %v", cert) - } - - // When cache has one certificate in it (also is default) - defaultCert := Certificate{Names: []string{"example.com", ""}, Certificate: tls.Certificate{Leaf: &x509.Certificate{DNSNames: []string{"example.com"}}}} - certCache[""] = defaultCert - certCache["example.com"] = defaultCert - if cert, err := GetCertificate(hello); err != nil { - t.Errorf("Got an error but shouldn't have, when cert exists in cache: %v", err) - } else if cert.Leaf.DNSNames[0] != "example.com" { - t.Errorf("Got wrong certificate with exact match; expected 'example.com', got: %v", cert) - } - if cert, err := GetCertificate(helloNoSNI); err != nil { - t.Errorf("Got an error with no SNI but shouldn't have, when cert exists in cache: %v", err) - } else if cert.Leaf.DNSNames[0] != "example.com" { - t.Errorf("Got wrong certificate for no SNI; expected 'example.com' as default, got: %v", cert) - } - - // When retrieving wildcard certificate - certCache["*.example.com"] = Certificate{Names: []string{"*.example.com"}, Certificate: tls.Certificate{Leaf: &x509.Certificate{DNSNames: []string{"*.example.com"}}}} - if cert, err := GetCertificate(helloSub); err != nil { - t.Errorf("Didn't get wildcard cert, got: cert=%v, err=%v ", cert, err) - } else if cert.Leaf.DNSNames[0] != "*.example.com" { - t.Errorf("Got wrong certificate, expected wildcard: %v", cert) - } - - // When no certificate matches, the default is returned - if cert, err := GetCertificate(helloNoMatch); err != nil { - t.Errorf("Expected default certificate with no error when no matches, got err: %v", err) - } else if cert.Leaf.DNSNames[0] != "example.com" { - t.Errorf("Expected default cert with no matches, got: %v", cert) - } -} diff --git a/core/https/https.go b/core/https/https.go deleted file mode 100644 index 99ef2fef6..000000000 --- a/core/https/https.go +++ /dev/null @@ -1,339 +0,0 @@ -// Package https facilitates the management of TLS assets and integrates -// Let's Encrypt functionality into CoreDNS with first-class support for -// creating and renewing certificates automatically. It is designed to -// configure sites for HTTPS by default. -package https - -import ( - "encoding/json" - "errors" - "io/ioutil" - "net" - "os" - - "github.com/miekg/coredns/server" - "github.com/xenolf/lego/acme" -) - -// Activate sets up TLS for each server config in configs -// as needed; this consists of acquiring and maintaining -// certificates and keys for qualifying configs and enabling -// OCSP stapling for all TLS-enabled configs. -// -// This function may prompt the user to provide an email -// address if none is available through other means. It -// prefers the email address specified in the config, but -// if that is not available it will check the command line -// argument. If absent, it will use the most recent email -// address from last time. If there isn't one, the user -// will be prompted and shown SA link. -// -// Also note that calling this function activates asset -// management automatically, which keeps certificates -// renewed and OCSP stapling updated. -// -// Activate returns the updated list of configs, since -// some may have been appended, for example, to redirect -// plaintext HTTP requests to their HTTPS counterpart. -// This function only appends; it does not splice. -func Activate(configs []server.Config) ([]server.Config, error) { - // just in case previous caller forgot... - Deactivate() - - // pre-screen each config and earmark the ones that qualify for managed TLS - MarkQualified(configs) - - // place certificates and keys on disk - err := ObtainCerts(configs, true, false) - if err != nil { - return configs, err - } - - // update TLS configurations - err = EnableTLS(configs, true) - if err != nil { - return configs, err - } - - // renew all relevant certificates that need renewal. this is important - // to do right away for a couple reasons, mainly because each restart, - // the renewal ticker is reset, so if restarts happen more often than - // the ticker interval, renewals would never happen. but doing - // it right away at start guarantees that renewals aren't missed. - err = renewManagedCertificates(true) - if err != nil { - return configs, err - } - - // keep certificates renewed and OCSP stapling updated - go maintainAssets(stopChan) - - return configs, nil -} - -// Deactivate cleans up long-term, in-memory resources -// allocated by calling Activate(). Essentially, it stops -// the asset maintainer from running, meaning that certificates -// will not be renewed, OCSP staples will not be updated, etc. -func Deactivate() (err error) { - defer func() { - if rec := recover(); rec != nil { - err = errors.New("already deactivated") - } - }() - close(stopChan) - stopChan = make(chan struct{}) - return -} - -// MarkQualified scans each config and, if it qualifies for managed -// TLS, it sets the Managed field of the TLSConfig to true. -func MarkQualified(configs []server.Config) { - for i := 0; i < len(configs); i++ { - if ConfigQualifies(configs[i]) { - configs[i].TLS.Managed = true - } - } -} - -// ObtainCerts obtains certificates for all these configs as long as a -// certificate does not already exist on disk. It does not modify the -// configs at all; it only obtains and stores certificates and keys to -// the disk. If allowPrompts is true, the user may be shown a prompt. -// If proxyACME is true, the ACME challenges will be proxied to our alt port. -func ObtainCerts(configs []server.Config, allowPrompts, proxyACME bool) error { - // We group configs by email so we don't make the same clients over and - // over. This has the potential to prompt the user for an email, but we - // prevent that by assuming that if we already have a listener that can - // proxy ACME challenge requests, then the server is already running and - // the operator is no longer present. - groupedConfigs := groupConfigsByEmail(configs, allowPrompts) - - for email, group := range groupedConfigs { - // Wait as long as we can before creating the client, because it - // may not be needed, for example, if we already have what we - // need on disk. Creating a client involves the network and - // potentially prompting the user, etc., so only do if necessary. - var client *ACMEClient - - for _, cfg := range group { - if existingCertAndKey(cfg.Host) { - continue - } - - // Now we definitely do need a client - if client == nil { - var err error - client, err = NewACMEClient(email, allowPrompts) - if err != nil { - return errors.New("error creating client: " + err.Error()) - } - } - - // c.Configure assumes that allowPrompts == !proxyACME, - // but that's not always true. For example, a restart where - // the user isn't present and we're not listening on port 80. - // TODO: This could probably be refactored better. - if proxyACME { - client.SetHTTPAddress(net.JoinHostPort(cfg.BindHost, AlternatePort)) - client.SetTLSAddress(net.JoinHostPort(cfg.BindHost, AlternatePort)) - client.ExcludeChallenges([]acme.Challenge{acme.TLSSNI01, acme.DNS01}) - } else { - client.SetHTTPAddress(net.JoinHostPort(cfg.BindHost, "")) - client.SetTLSAddress(net.JoinHostPort(cfg.BindHost, "")) - client.ExcludeChallenges([]acme.Challenge{acme.DNS01}) - } - - err := client.Obtain([]string{cfg.Host}) - if err != nil { - return err - } - } - } - - return nil -} - -// groupConfigsByEmail groups configs by the email address to be used by an -// ACME client. It only groups configs that have TLS enabled and that are -// marked as Managed. If userPresent is true, the operator MAY be prompted -// for an email address. -func groupConfigsByEmail(configs []server.Config, userPresent bool) map[string][]server.Config { - initMap := make(map[string][]server.Config) - for _, cfg := range configs { - if !cfg.TLS.Managed { - continue - } - leEmail := getEmail(cfg, userPresent) - initMap[leEmail] = append(initMap[leEmail], cfg) - } - return initMap -} - -// EnableTLS configures each config to use TLS according to default settings. -// It will only change configs that are marked as managed, and assumes that -// certificates and keys are already on disk. If loadCertificates is true, -// the certificates will be loaded from disk into the cache for this process -// to use. If false, TLS will still be enabled and configured with default -// settings, but no certificates will be parsed loaded into the cache, and -// the returned error value will always be nil. -func EnableTLS(configs []server.Config, loadCertificates bool) error { - for i := 0; i < len(configs); i++ { - if !configs[i].TLS.Managed { - continue - } - configs[i].TLS.Enabled = true - if loadCertificates { - _, err := cacheManagedCertificate(configs[i].Host, false) - if err != nil { - return err - } - } - setDefaultTLSParams(&configs[i]) - } - return nil -} - -// hostHasOtherPort returns true if there is another config in the list with the same -// hostname that has port otherPort, or false otherwise. All the configs are checked -// against the hostname of allConfigs[thisConfigIdx]. -func hostHasOtherPort(allConfigs []server.Config, thisConfigIdx int, otherPort string) bool { - for i, otherCfg := range allConfigs { - if i == thisConfigIdx { - continue // has to be a config OTHER than the one we're comparing against - } - if otherCfg.Host == allConfigs[thisConfigIdx].Host && otherCfg.Port == otherPort { - return true - } - } - return false -} - -// ConfigQualifies returns true if cfg qualifies for -// fully managed TLS (but not on-demand TLS, which is -// not considered here). It does NOT check to see if a -// cert and key already exist for the config. If the -// config does qualify, you should set cfg.TLS.Managed -// to true and check that instead, because the process of -// setting up the config may make it look like it -// doesn't qualify even though it originally did. -func ConfigQualifies(cfg server.Config) bool { - return (!cfg.TLS.Manual || cfg.TLS.OnDemand) && // user might provide own cert and key - - // user can force-disable automatic HTTPS for this host - cfg.Port != "80" && - cfg.TLS.LetsEncryptEmail != "off" && - - // we get can't certs for some kinds of hostnames, but - // on-demand TLS allows empty hostnames at startup - cfg.TLS.OnDemand -} - -// existingCertAndKey returns true if the host has a certificate -// and private key in storage already, false otherwise. -func existingCertAndKey(host string) bool { - _, err := os.Stat(storage.SiteCertFile(host)) - if err != nil { - return false - } - _, err = os.Stat(storage.SiteKeyFile(host)) - if err != nil { - return false - } - return true -} - -// saveCertResource saves the certificate resource to disk. This -// includes the certificate file itself, the private key, and the -// metadata file. -func saveCertResource(cert acme.CertificateResource) error { - err := os.MkdirAll(storage.Site(cert.Domain), 0700) - if err != nil { - return err - } - - // Save cert - err = ioutil.WriteFile(storage.SiteCertFile(cert.Domain), cert.Certificate, 0600) - if err != nil { - return err - } - - // Save private key - err = ioutil.WriteFile(storage.SiteKeyFile(cert.Domain), cert.PrivateKey, 0600) - if err != nil { - return err - } - - // Save cert metadata - jsonBytes, err := json.MarshalIndent(&cert, "", "\t") - if err != nil { - return err - } - err = ioutil.WriteFile(storage.SiteMetaFile(cert.Domain), jsonBytes, 0600) - if err != nil { - return err - } - - return nil -} - -// Revoke revokes the certificate for host via ACME protocol. -func Revoke(host string) error { - if !existingCertAndKey(host) { - return errors.New("no certificate and key for " + host) - } - - email := getEmail(server.Config{Host: host}, true) - if email == "" { - return errors.New("email is required to revoke") - } - - client, err := NewACMEClient(email, true) - if err != nil { - return err - } - - certFile := storage.SiteCertFile(host) - certBytes, err := ioutil.ReadFile(certFile) - if err != nil { - return err - } - - err = client.RevokeCertificate(certBytes) - if err != nil { - return err - } - - err = os.Remove(certFile) - if err != nil { - return errors.New("certificate revoked, but unable to delete certificate file: " + err.Error()) - } - - return nil -} - -var ( - // DefaultEmail represents the Let's Encrypt account email to use if none provided - DefaultEmail string - - // Agreed indicates whether user has agreed to the Let's Encrypt SA - Agreed bool - - // CAUrl represents the base URL to the CA's ACME endpoint - CAUrl string -) - -// AlternatePort is the port on which the acme client will open a -// listener and solve the CA's challenges. If this alternate port -// is used instead of the default port (80 or 443), then the -// default port for the challenge must be forwarded to this one. -const AlternatePort = "5033" - -// KeyType is the type to use for new keys. -// This shouldn't need to change except for in tests; -// the size can be drastically reduced for speed. -var KeyType = acme.EC384 - -// stopChan is used to signal the maintenance goroutine -// to terminate. -var stopChan chan struct{} diff --git a/core/https/https_test.go b/core/https/https_test.go deleted file mode 100644 index f19b3cde0..000000000 --- a/core/https/https_test.go +++ /dev/null @@ -1,323 +0,0 @@ -package https - -/* -func TestHostQualifies(t *testing.T) { - for i, test := range []struct { - host string - expect bool - }{ - {"localhost", false}, - {"127.0.0.1", false}, - {"127.0.1.5", false}, - {"::1", false}, - {"[::1]", false}, - {"[::]", false}, - {"::", false}, - {"", false}, - {" ", false}, - {"0.0.0.0", false}, - {"192.168.1.3", false}, - {"10.0.2.1", false}, - {"169.112.53.4", false}, - {"foobar.com", true}, - {"sub.foobar.com", true}, - } { - if HostQualifies(test.host) && !test.expect { - t.Errorf("Test %d: Expected '%s' to NOT qualify, but it did", i, test.host) - } - if !HostQualifies(test.host) && test.expect { - t.Errorf("Test %d: Expected '%s' to qualify, but it did NOT", i, test.host) - } - } -} - -func TestConfigQualifies(t *testing.T) { - for i, test := range []struct { - cfg server.Config - expect bool - }{ - {server.Config{Host: ""}, false}, - {server.Config{Host: "localhost"}, false}, - {server.Config{Host: "123.44.3.21"}, false}, - {server.Config{Host: "example.com"}, true}, - {server.Config{Host: "example.com", TLS: server.TLSConfig{Manual: true}}, false}, - {server.Config{Host: "example.com", TLS: server.TLSConfig{LetsEncryptEmail: "off"}}, false}, - {server.Config{Host: "example.com", TLS: server.TLSConfig{LetsEncryptEmail: "foo@bar.com"}}, true}, - {server.Config{Host: "example.com", Scheme: "http"}, false}, - {server.Config{Host: "example.com", Port: "80"}, false}, - {server.Config{Host: "example.com", Port: "1234"}, true}, - {server.Config{Host: "example.com", Scheme: "https"}, true}, - {server.Config{Host: "example.com", Port: "80", Scheme: "https"}, false}, - } { - if test.expect && !ConfigQualifies(test.cfg) { - t.Errorf("Test %d: Expected config to qualify, but it did NOT: %#v", i, test.cfg) - } - if !test.expect && ConfigQualifies(test.cfg) { - t.Errorf("Test %d: Expected config to NOT qualify, but it did: %#v", i, test.cfg) - } - } -} - -func TestRedirPlaintextHost(t *testing.T) { - cfg := redirPlaintextHost(server.Config{ - Host: "example.com", - BindHost: "93.184.216.34", - Port: "1234", - }) - - // Check host and port - if actual, expected := cfg.Host, "example.com"; actual != expected { - t.Errorf("Expected redir config to have host %s but got %s", expected, actual) - } - if actual, expected := cfg.BindHost, "93.184.216.34"; actual != expected { - t.Errorf("Expected redir config to have bindhost %s but got %s", expected, actual) - } - if actual, expected := cfg.Port, "80"; actual != expected { - t.Errorf("Expected redir config to have port '%s' but got '%s'", expected, actual) - } - - // Make sure redirect handler is set up properly - if cfg.Middleware == nil || len(cfg.Middleware) != 1 { - t.Fatalf("Redir config middleware not set up properly; got: %#v", cfg.Middleware) - } - - handler, ok := cfg.Middleware[0](nil).(redirect.Redirect) - if !ok { - t.Fatalf("Expected a redirect.Redirect middleware, but got: %#v", handler) - } - if len(handler.Rules) != 1 { - t.Fatalf("Expected one redirect rule, got: %#v", handler.Rules) - } - - // Check redirect rule for correctness - if actual, expected := handler.Rules[0].FromScheme, "http"; actual != expected { - t.Errorf("Expected redirect rule to be from scheme '%s' but is actually from '%s'", expected, actual) - } - if actual, expected := handler.Rules[0].FromPath, "/"; actual != expected { - t.Errorf("Expected redirect rule to be for path '%s' but is actually for '%s'", expected, actual) - } - if actual, expected := handler.Rules[0].To, "https://{host}:1234{uri}"; actual != expected { - t.Errorf("Expected redirect rule to be to URL '%s' but is actually to '%s'", expected, actual) - } - if actual, expected := handler.Rules[0].Code, http.StatusMovedPermanently; actual != expected { - t.Errorf("Expected redirect rule to have code %d but was %d", expected, actual) - } - - // browsers can infer a default port from scheme, so make sure the port - // doesn't get added in explicitly for default ports like 443 for https. - cfg = redirPlaintextHost(server.Config{Host: "example.com", Port: "443"}) - handler, ok = cfg.Middleware[0](nil).(redirect.Redirect) - if actual, expected := handler.Rules[0].To, "https://{host}{uri}"; actual != expected { - t.Errorf("(Default Port) Expected redirect rule to be to URL '%s' but is actually to '%s'", expected, actual) - } -} - -func TestSaveCertResource(t *testing.T) { - storage = Storage("./le_test_save") - defer func() { - err := os.RemoveAll(string(storage)) - if err != nil { - t.Fatalf("Could not remove temporary storage directory (%s): %v", storage, err) - } - }() - - domain := "example.com" - certContents := "certificate" - keyContents := "private key" - metaContents := `{ - "domain": "example.com", - "certUrl": "https://example.com/cert", - "certStableUrl": "https://example.com/cert/stable" -}` - - cert := acme.CertificateResource{ - Domain: domain, - CertURL: "https://example.com/cert", - CertStableURL: "https://example.com/cert/stable", - PrivateKey: []byte(keyContents), - Certificate: []byte(certContents), - } - - err := saveCertResource(cert) - if err != nil { - t.Fatalf("Expected no error, got: %v", err) - } - - certFile, err := ioutil.ReadFile(storage.SiteCertFile(domain)) - if err != nil { - t.Errorf("Expected no error reading certificate file, got: %v", err) - } - if string(certFile) != certContents { - t.Errorf("Expected certificate file to contain '%s', got '%s'", certContents, string(certFile)) - } - - keyFile, err := ioutil.ReadFile(storage.SiteKeyFile(domain)) - if err != nil { - t.Errorf("Expected no error reading private key file, got: %v", err) - } - if string(keyFile) != keyContents { - t.Errorf("Expected private key file to contain '%s', got '%s'", keyContents, string(keyFile)) - } - - metaFile, err := ioutil.ReadFile(storage.SiteMetaFile(domain)) - if err != nil { - t.Errorf("Expected no error reading meta file, got: %v", err) - } - if string(metaFile) != metaContents { - t.Errorf("Expected meta file to contain '%s', got '%s'", metaContents, string(metaFile)) - } -} - -func TestExistingCertAndKey(t *testing.T) { - storage = Storage("./le_test_existing") - defer func() { - err := os.RemoveAll(string(storage)) - if err != nil { - t.Fatalf("Could not remove temporary storage directory (%s): %v", storage, err) - } - }() - - domain := "example.com" - - if existingCertAndKey(domain) { - t.Errorf("Did NOT expect %v to have existing cert or key, but it did", domain) - } - - err := saveCertResource(acme.CertificateResource{ - Domain: domain, - PrivateKey: []byte("key"), - Certificate: []byte("cert"), - }) - if err != nil { - t.Fatalf("Expected no error, got: %v", err) - } - - if !existingCertAndKey(domain) { - t.Errorf("Expected %v to have existing cert and key, but it did NOT", domain) - } -} - -func TestHostHasOtherPort(t *testing.T) { - configs := []server.Config{ - {Host: "example.com", Port: "80"}, - {Host: "sub1.example.com", Port: "80"}, - {Host: "sub1.example.com", Port: "443"}, - } - - if hostHasOtherPort(configs, 0, "80") { - t.Errorf(`Expected hostHasOtherPort(configs, 0, "80") to be false, but got true`) - } - if hostHasOtherPort(configs, 0, "443") { - t.Errorf(`Expected hostHasOtherPort(configs, 0, "443") to be false, but got true`) - } - if !hostHasOtherPort(configs, 1, "443") { - t.Errorf(`Expected hostHasOtherPort(configs, 1, "443") to be true, but got false`) - } -} - -func TestMakePlaintextRedirects(t *testing.T) { - configs := []server.Config{ - // Happy path = standard redirect from 80 to 443 - {Host: "example.com", TLS: server.TLSConfig{Managed: true}}, - - // Host on port 80 already defined; don't change it (no redirect) - {Host: "sub1.example.com", Port: "80", Scheme: "http"}, - {Host: "sub1.example.com", TLS: server.TLSConfig{Managed: true}}, - - // Redirect from port 80 to port 5000 in this case - {Host: "sub2.example.com", Port: "5000", TLS: server.TLSConfig{Managed: true}}, - - // Can redirect from 80 to either 443 or 5001, but choose 443 - {Host: "sub3.example.com", Port: "443", TLS: server.TLSConfig{Managed: true}}, - {Host: "sub3.example.com", Port: "5001", Scheme: "https", TLS: server.TLSConfig{Managed: true}}, - } - - result := MakePlaintextRedirects(configs) - expectedRedirCount := 3 - - if len(result) != len(configs)+expectedRedirCount { - t.Errorf("Expected %d redirect(s) to be added, but got %d", - expectedRedirCount, len(result)-len(configs)) - } -} - -func TestEnableTLS(t *testing.T) { - configs := []server.Config{ - {Host: "example.com", TLS: server.TLSConfig{Managed: true}}, - {}, // not managed - no changes! - } - - EnableTLS(configs, false) - - if !configs[0].TLS.Enabled { - t.Errorf("Expected config 0 to have TLS.Enabled == true, but it was false") - } - if configs[1].TLS.Enabled { - t.Errorf("Expected config 1 to have TLS.Enabled == false, but it was true") - } -} - -func TestGroupConfigsByEmail(t *testing.T) { - if groupConfigsByEmail([]server.Config{}, false) == nil { - t.Errorf("With empty input, returned map was nil, but expected non-nil map") - } - - configs := []server.Config{ - {Host: "example.com", TLS: server.TLSConfig{LetsEncryptEmail: "", Managed: true}}, - {Host: "sub1.example.com", TLS: server.TLSConfig{LetsEncryptEmail: "foo@bar", Managed: true}}, - {Host: "sub2.example.com", TLS: server.TLSConfig{LetsEncryptEmail: "", Managed: true}}, - {Host: "sub3.example.com", TLS: server.TLSConfig{LetsEncryptEmail: "foo@bar", Managed: true}}, - {Host: "sub4.example.com", TLS: server.TLSConfig{LetsEncryptEmail: "", Managed: true}}, - {Host: "sub5.example.com", TLS: server.TLSConfig{LetsEncryptEmail: ""}}, // not managed - } - DefaultEmail = "test@example.com" - - groups := groupConfigsByEmail(configs, true) - - if groups == nil { - t.Fatalf("Returned map was nil, but expected values") - } - - if len(groups) != 2 { - t.Errorf("Expected 2 groups, got %d: %#v", len(groups), groups) - } - if len(groups["foo@bar"]) != 2 { - t.Errorf("Expected 2 configs for foo@bar, got %d: %#v", len(groups["foobar"]), groups["foobar"]) - } - if len(groups[DefaultEmail]) != 3 { - t.Errorf("Expected 3 configs for %s, got %d: %#v", DefaultEmail, len(groups["foobar"]), groups["foobar"]) - } -} - -func TestMarkQualified(t *testing.T) { - // TODO: TestConfigQualifies and this test share the same config list... - configs := []server.Config{ - {Host: ""}, - {Host: "localhost"}, - {Host: "123.44.3.21"}, - {Host: "example.com"}, - {Host: "example.com", TLS: server.TLSConfig{Manual: true}}, - {Host: "example.com", TLS: server.TLSConfig{LetsEncryptEmail: "off"}}, - {Host: "example.com", TLS: server.TLSConfig{LetsEncryptEmail: "foo@bar.com"}}, - {Host: "example.com", Scheme: "http"}, - {Host: "example.com", Port: "80"}, - {Host: "example.com", Port: "1234"}, - {Host: "example.com", Scheme: "https"}, - {Host: "example.com", Port: "80", Scheme: "https"}, - } - expectedManagedCount := 4 - - MarkQualified(configs) - - count := 0 - for _, cfg := range configs { - if cfg.TLS.Managed { - count++ - } - } - - if count != expectedManagedCount { - t.Errorf("Expected %d managed configs, but got %d", expectedManagedCount, count) - } -} -*/ diff --git a/core/https/maintain.go b/core/https/maintain.go deleted file mode 100644 index 46fd3d1f5..000000000 --- a/core/https/maintain.go +++ /dev/null @@ -1,211 +0,0 @@ -package https - -import ( - "log" - "time" - - "github.com/miekg/coredns/server" - - "golang.org/x/crypto/ocsp" -) - -const ( - // RenewInterval is how often to check certificates for renewal. - RenewInterval = 12 * time.Hour - - // OCSPInterval is how often to check if OCSP stapling needs updating. - OCSPInterval = 1 * time.Hour -) - -// maintainAssets is a permanently-blocking function -// that loops indefinitely and, on a regular schedule, checks -// certificates for expiration and initiates a renewal of certs -// that are expiring soon. It also updates OCSP stapling and -// performs other maintenance of assets. -// -// You must pass in the channel which you'll close when -// maintenance should stop, to allow this goroutine to clean up -// after itself and unblock. -func maintainAssets(stopChan chan struct{}) { - renewalTicker := time.NewTicker(RenewInterval) - ocspTicker := time.NewTicker(OCSPInterval) - - for { - select { - case <-renewalTicker.C: - log.Println("[INFO] Scanning for expiring certificates") - renewManagedCertificates(false) - log.Println("[INFO] Done checking certificates") - case <-ocspTicker.C: - log.Println("[INFO] Scanning for stale OCSP staples") - updateOCSPStaples() - log.Println("[INFO] Done checking OCSP staples") - case <-stopChan: - renewalTicker.Stop() - ocspTicker.Stop() - log.Println("[INFO] Stopped background maintenance routine") - return - } - } -} - -func renewManagedCertificates(allowPrompts bool) (err error) { - var renewed, deleted []Certificate - var client *ACMEClient - visitedNames := make(map[string]struct{}) - - certCacheMu.RLock() - for name, cert := range certCache { - if !cert.Managed { - continue - } - - // the list of names on this cert should never be empty... - if cert.Names == nil || len(cert.Names) == 0 { - log.Printf("[WARNING] Certificate keyed by '%s' has no names: %v", name, cert.Names) - deleted = append(deleted, cert) - continue - } - - // skip names whose certificate we've already renewed - if _, ok := visitedNames[name]; ok { - continue - } - for _, name := range cert.Names { - visitedNames[name] = struct{}{} - } - - timeLeft := cert.NotAfter.Sub(time.Now().UTC()) - if timeLeft < renewDurationBefore { - log.Printf("[INFO] Certificate for %v expires in %v; attempting renewal", cert.Names, timeLeft) - - if client == nil { - client, err = NewACMEClientGetEmail(server.Config{}, allowPrompts) - if err != nil { - return err - } - client.Configure("") // TODO: Bind address of relevant listener, yuck - } - - err := client.Renew(cert.Names[0]) // managed certs better have only one name - if err != nil { - if client.AllowPrompts && timeLeft < 0 { - // Certificate renewal failed, the operator is present, and the certificate - // is already expired; we should stop immediately and return the error. Note - // that we used to do this any time a renewal failed at startup. However, - // after discussion in https://github.com/miekg/coredns/issues/642 we decided to - // only stop startup if the certificate is expired. We still log the error - // otherwise. - certCacheMu.RUnlock() - return err - } - log.Printf("[ERROR] %v", err) - if cert.OnDemand { - deleted = append(deleted, cert) - } - } else { - renewed = append(renewed, cert) - } - } - } - certCacheMu.RUnlock() - - // Apply changes to the cache - for _, cert := range renewed { - _, err := cacheManagedCertificate(cert.Names[0], cert.OnDemand) - if err != nil { - if client.AllowPrompts { - return err // operator is present, so report error immediately - } - log.Printf("[ERROR] %v", err) - } - } - for _, cert := range deleted { - certCacheMu.Lock() - for _, name := range cert.Names { - delete(certCache, name) - } - certCacheMu.Unlock() - } - - return nil -} - -func updateOCSPStaples() { - // Create a temporary place to store updates - // until we release the potentially long-lived - // read lock and use a short-lived write lock. - type ocspUpdate struct { - rawBytes []byte - parsed *ocsp.Response - } - updated := make(map[string]ocspUpdate) - - // A single SAN certificate maps to multiple names, so we use this - // set to make sure we don't waste cycles checking OCSP for the same - // certificate multiple times. - visited := make(map[string]struct{}) - - certCacheMu.RLock() - for name, cert := range certCache { - // skip this certificate if we've already visited it, - // and if not, mark all the names as visited - if _, ok := visited[name]; ok { - continue - } - for _, n := range cert.Names { - visited[n] = struct{}{} - } - - // no point in updating OCSP for expired certificates - if time.Now().After(cert.NotAfter) { - continue - } - - var lastNextUpdate time.Time - if cert.OCSP != nil { - // start checking OCSP staple about halfway through validity period for good measure - lastNextUpdate = cert.OCSP.NextUpdate - refreshTime := cert.OCSP.ThisUpdate.Add(lastNextUpdate.Sub(cert.OCSP.ThisUpdate) / 2) - - // since OCSP is already stapled, we need only check if we're in that "refresh window" - if time.Now().Before(refreshTime) { - continue - } - } - - err := stapleOCSP(&cert, nil) - if err != nil { - if cert.OCSP != nil { - // if it was no staple before, that's fine, otherwise we should log the error - log.Printf("[ERROR] Checking OCSP for %s: %v", name, err) - } - continue - } - - // By this point, we've obtained the latest OCSP response. - // If there was no staple before, or if the response is updated, make - // sure we apply the update to all names on the certificate. - if lastNextUpdate.IsZero() || lastNextUpdate != cert.OCSP.NextUpdate { - log.Printf("[INFO] Advancing OCSP staple for %v from %s to %s", - cert.Names, lastNextUpdate, cert.OCSP.NextUpdate) - for _, n := range cert.Names { - updated[n] = ocspUpdate{rawBytes: cert.Certificate.OCSPStaple, parsed: cert.OCSP} - } - } - } - certCacheMu.RUnlock() - - // This write lock should be brief since we have all the info we need now. - certCacheMu.Lock() - for name, update := range updated { - cert := certCache[name] - cert.OCSP = update.parsed - cert.Certificate.OCSPStaple = update.rawBytes - certCache[name] = cert - } - certCacheMu.Unlock() -} - -// renewDurationBefore is how long before expiration to renew certificates. -const renewDurationBefore = (24 * time.Hour) * 30 diff --git a/core/https/setup.go b/core/https/setup.go deleted file mode 100644 index ec90e0284..000000000 --- a/core/https/setup.go +++ /dev/null @@ -1,321 +0,0 @@ -package https - -import ( - "bytes" - "crypto/tls" - "encoding/pem" - "io/ioutil" - "log" - "os" - "path/filepath" - "strconv" - "strings" - - "github.com/miekg/coredns/core/setup" - "github.com/miekg/coredns/middleware" - "github.com/miekg/coredns/server" -) - -// Setup sets up the TLS configuration and installs certificates that -// are specified by the user in the config file. All the automatic HTTPS -// stuff comes later outside of this function. -func Setup(c *setup.Controller) (middleware.Middleware, error) { - if c.Port == "80" { - c.TLS.Enabled = false - log.Printf("[WARNING] TLS disabled for %s.", c.Address()) - return nil, nil - } - c.TLS.Enabled = true - - // TODO(miek): disabled for now - return nil, nil - - for c.Next() { - var certificateFile, keyFile, loadDir, maxCerts string - - args := c.RemainingArgs() - switch len(args) { - case 1: - c.TLS.LetsEncryptEmail = args[0] - - // user can force-disable managed TLS this way - if c.TLS.LetsEncryptEmail == "off" { - c.TLS.Enabled = false - return nil, nil - } - case 2: - certificateFile = args[0] - keyFile = args[1] - c.TLS.Manual = true - } - - // Optional block with extra parameters - var hadBlock bool - for c.NextBlock() { - hadBlock = true - switch c.Val() { - case "protocols": - args := c.RemainingArgs() - if len(args) != 2 { - return nil, c.ArgErr() - } - value, ok := supportedProtocols[strings.ToLower(args[0])] - if !ok { - return nil, c.Errf("Wrong protocol name or protocol not supported '%s'", c.Val()) - } - c.TLS.ProtocolMinVersion = value - value, ok = supportedProtocols[strings.ToLower(args[1])] - if !ok { - return nil, c.Errf("Wrong protocol name or protocol not supported '%s'", c.Val()) - } - c.TLS.ProtocolMaxVersion = value - case "ciphers": - for c.NextArg() { - value, ok := supportedCiphersMap[strings.ToUpper(c.Val())] - if !ok { - return nil, c.Errf("Wrong cipher name or cipher not supported '%s'", c.Val()) - } - c.TLS.Ciphers = append(c.TLS.Ciphers, value) - } - case "clients": - c.TLS.ClientCerts = c.RemainingArgs() - if len(c.TLS.ClientCerts) == 0 { - return nil, c.ArgErr() - } - case "load": - c.Args(&loadDir) - c.TLS.Manual = true - case "max_certs": - c.Args(&maxCerts) - c.TLS.OnDemand = true - default: - return nil, c.Errf("Unknown keyword '%s'", c.Val()) - } - } - - // tls requires at least one argument if a block is not opened - if len(args) == 0 && !hadBlock { - return nil, c.ArgErr() - } - - // set certificate limit if on-demand TLS is enabled - if maxCerts != "" { - maxCertsNum, err := strconv.Atoi(maxCerts) - if err != nil || maxCertsNum < 1 { - return nil, c.Err("max_certs must be a positive integer") - } - if onDemandMaxIssue == 0 || int32(maxCertsNum) < onDemandMaxIssue { // keep the minimum; TODO: We have to do this because it is global; should be per-server or per-vhost... - onDemandMaxIssue = int32(maxCertsNum) - } - } - - // don't try to load certificates unless we're supposed to - if !c.TLS.Enabled || !c.TLS.Manual { - continue - } - - // load a single certificate and key, if specified - if certificateFile != "" && keyFile != "" { - err := cacheUnmanagedCertificatePEMFile(certificateFile, keyFile) - if err != nil { - return nil, c.Errf("Unable to load certificate and key files for %s: %v", c.Host, err) - } - log.Printf("[INFO] Successfully loaded TLS assets from %s and %s", certificateFile, keyFile) - } - - // load a directory of certificates, if specified - if loadDir != "" { - err := loadCertsInDir(c, loadDir) - if err != nil { - return nil, err - } - } - } - - setDefaultTLSParams(c.Config) - - return nil, nil -} - -// loadCertsInDir loads all the certificates/keys in dir, as long as -// the file ends with .pem. This method of loading certificates is -// modeled after haproxy, which expects the certificate and key to -// be bundled into the same file: -// https://cbonte.github.io/haproxy-dconv/configuration-1.5.html#5.1-crt -// -// This function may write to the log as it walks the directory tree. -func loadCertsInDir(c *setup.Controller, dir string) error { - return filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { - if err != nil { - log.Printf("[WARNING] Unable to traverse into %s; skipping", path) - return nil - } - if info.IsDir() { - return nil - } - if strings.HasSuffix(strings.ToLower(info.Name()), ".pem") { - certBuilder, keyBuilder := new(bytes.Buffer), new(bytes.Buffer) - var foundKey bool // use only the first key in the file - - bundle, err := ioutil.ReadFile(path) - if err != nil { - return err - } - - for { - // Decode next block so we can see what type it is - var derBlock *pem.Block - derBlock, bundle = pem.Decode(bundle) - if derBlock == nil { - break - } - - if derBlock.Type == "CERTIFICATE" { - // Re-encode certificate as PEM, appending to certificate chain - pem.Encode(certBuilder, derBlock) - } else if derBlock.Type == "EC PARAMETERS" { - // EC keys generated from openssl can be composed of two blocks: - // parameters and key (parameter block should come first) - if !foundKey { - // Encode parameters - pem.Encode(keyBuilder, derBlock) - - // Key must immediately follow - derBlock, bundle = pem.Decode(bundle) - if derBlock == nil || derBlock.Type != "EC PRIVATE KEY" { - return c.Errf("%s: expected elliptic private key to immediately follow EC parameters", path) - } - pem.Encode(keyBuilder, derBlock) - foundKey = true - } - } else if derBlock.Type == "PRIVATE KEY" || strings.HasSuffix(derBlock.Type, " PRIVATE KEY") { - // RSA key - if !foundKey { - pem.Encode(keyBuilder, derBlock) - foundKey = true - } - } else { - return c.Errf("%s: unrecognized PEM block type: %s", path, derBlock.Type) - } - } - - certPEMBytes, keyPEMBytes := certBuilder.Bytes(), keyBuilder.Bytes() - if len(certPEMBytes) == 0 { - return c.Errf("%s: failed to parse PEM data", path) - } - if len(keyPEMBytes) == 0 { - return c.Errf("%s: no private key block found", path) - } - - err = cacheUnmanagedCertificatePEMBytes(certPEMBytes, keyPEMBytes) - if err != nil { - return c.Errf("%s: failed to load cert and key for %s: %v", path, c.Host, err) - } - log.Printf("[INFO] Successfully loaded TLS assets from %s", path) - } - return nil - }) -} - -// setDefaultTLSParams sets the default TLS cipher suites, protocol versions, -// and server preferences of a server.Config if they were not previously set -// (it does not overwrite; only fills in missing values). It will also set the -// port to 443 if not already set, TLS is enabled, TLS is manual, and the host -// does not equal localhost. -func setDefaultTLSParams(c *server.Config) { - // If no ciphers provided, use default list - if len(c.TLS.Ciphers) == 0 { - c.TLS.Ciphers = defaultCiphers - } - - // Not a cipher suite, but still important for mitigating protocol downgrade attacks - // (prepend since having it at end breaks http2 due to non-h2-approved suites before it) - c.TLS.Ciphers = append([]uint16{tls.TLS_FALLBACK_SCSV}, c.TLS.Ciphers...) - - // Set default protocol min and max versions - must balance compatibility and security - if c.TLS.ProtocolMinVersion == 0 { - c.TLS.ProtocolMinVersion = tls.VersionTLS10 - } - if c.TLS.ProtocolMaxVersion == 0 { - c.TLS.ProtocolMaxVersion = tls.VersionTLS12 - } - - // Prefer server cipher suites - c.TLS.PreferServerCipherSuites = true - - // Default TLS port is 443; only use if port is not manually specified, - // TLS is enabled, and the host is not localhost - if c.Port == "" && c.TLS.Enabled && (!c.TLS.Manual || c.TLS.OnDemand) && c.Host != "localhost" { - c.Port = "443" - } -} - -// Map of supported protocols. -// SSLv3 will be not supported in future release. -// HTTP/2 only supports TLS 1.2 and higher. -var supportedProtocols = map[string]uint16{ - "ssl3.0": tls.VersionSSL30, - "tls1.0": tls.VersionTLS10, - "tls1.1": tls.VersionTLS11, - "tls1.2": tls.VersionTLS12, -} - -// Map of supported ciphers, used only for parsing config. -// -// Note that, at time of writing, HTTP/2 blacklists 276 cipher suites, -// including all but two of the suites below (the two GCM suites). -// See https://http2.github.io/http2-spec/#BadCipherSuites -// -// TLS_FALLBACK_SCSV is not in this list because we manually ensure -// it is always added (even though it is not technically a cipher suite). -// -// This map, like any map, is NOT ORDERED. Do not range over this map. -var supportedCiphersMap = map[string]uint16{ - "ECDHE-RSA-AES256-GCM-SHA384": tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, - "ECDHE-ECDSA-AES256-GCM-SHA384": tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, - "ECDHE-RSA-AES128-GCM-SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, - "ECDHE-ECDSA-AES128-GCM-SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, - "ECDHE-RSA-AES128-CBC-SHA": tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, - "ECDHE-RSA-AES256-CBC-SHA": tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, - "ECDHE-ECDSA-AES256-CBC-SHA": tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, - "ECDHE-ECDSA-AES128-CBC-SHA": tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, - "RSA-AES128-CBC-SHA": tls.TLS_RSA_WITH_AES_128_CBC_SHA, - "RSA-AES256-CBC-SHA": tls.TLS_RSA_WITH_AES_256_CBC_SHA, - "ECDHE-RSA-3DES-EDE-CBC-SHA": tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, - "RSA-3DES-EDE-CBC-SHA": tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA, -} - -// List of supported cipher suites in descending order of preference. -// Ordering is very important! Getting the wrong order will break -// mainstream clients, especially with HTTP/2. -// -// Note that TLS_FALLBACK_SCSV is not in this list since it is always -// added manually. -var supportedCiphers = []uint16{ - tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, - tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, - tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, - tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, - tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, - tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, - tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, - tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, - tls.TLS_RSA_WITH_AES_256_CBC_SHA, - tls.TLS_RSA_WITH_AES_128_CBC_SHA, - tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, - tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA, -} - -// List of all the ciphers we want to use by default -var defaultCiphers = []uint16{ - tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, - tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, - tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, - tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, - tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, - tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, - tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, - tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, - tls.TLS_RSA_WITH_AES_256_CBC_SHA, - tls.TLS_RSA_WITH_AES_128_CBC_SHA, -} diff --git a/core/https/setup_test.go b/core/https/setup_test.go deleted file mode 100644 index 7640eb524..000000000 --- a/core/https/setup_test.go +++ /dev/null @@ -1,226 +0,0 @@ -package https - -// TODO(miek): all fail - -/* -func TestMain(m *testing.M) { - // Write test certificates to disk before tests, and clean up - // when we're done. - err := ioutil.WriteFile(certFile, testCert, 0644) - if err != nil { - log.Fatal(err) - } - err = ioutil.WriteFile(keyFile, testKey, 0644) - if err != nil { - os.Remove(certFile) - log.Fatal(err) - } - - result := m.Run() - - os.Remove(certFile) - os.Remove(keyFile) - os.Exit(result) -} - -func TestSetupParseBasic(t *testing.T) { - c := setup.NewTestController(`tls ` + certFile + ` ` + keyFile + ``) - - _, err := Setup(c) - if err != nil { - t.Errorf("Expected no errors, got: %v", err) - } - - // Basic checks - if !c.TLS.Manual { - t.Error("Expected TLS Manual=true, but was false") - } - if !c.TLS.Enabled { - t.Error("Expected TLS Enabled=true, but was false") - } - - // Security defaults - if c.TLS.ProtocolMinVersion != tls.VersionTLS10 { - t.Errorf("Expected 'tls1.0 (0x0301)' as ProtocolMinVersion, got %#v", c.TLS.ProtocolMinVersion) - } - if c.TLS.ProtocolMaxVersion != tls.VersionTLS12 { - t.Errorf("Expected 'tls1.2 (0x0303)' as ProtocolMaxVersion, got %v", c.TLS.ProtocolMaxVersion) - } - - // Cipher checks - expectedCiphers := []uint16{ - tls.TLS_FALLBACK_SCSV, - tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, - tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, - tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, - tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, - tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, - tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, - tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, - tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, - tls.TLS_RSA_WITH_AES_256_CBC_SHA, - tls.TLS_RSA_WITH_AES_128_CBC_SHA, - } - - // Ensure count is correct (plus one for TLS_FALLBACK_SCSV) - if len(c.TLS.Ciphers) != len(expectedCiphers) { - t.Errorf("Expected %v Ciphers (including TLS_FALLBACK_SCSV), got %v", - len(expectedCiphers), len(c.TLS.Ciphers)) - } - - // Ensure ordering is correct - for i, actual := range c.TLS.Ciphers { - if actual != expectedCiphers[i] { - t.Errorf("Expected cipher in position %d to be %0x, got %0x", i, expectedCiphers[i], actual) - } - } - - if !c.TLS.PreferServerCipherSuites { - t.Error("Expected PreferServerCipherSuites = true, but was false") - } -} - -func TestSetupParseIncompleteParams(t *testing.T) { - // Using tls without args is an error because it's unnecessary. - c := setup.NewTestController(`tls`) - _, err := Setup(c) - if err == nil { - t.Error("Expected an error, but didn't get one") - } -} - -func TestSetupParseWithOptionalParams(t *testing.T) { - params := `tls ` + certFile + ` ` + keyFile + ` { - protocols ssl3.0 tls1.2 - ciphers RSA-AES256-CBC-SHA ECDHE-RSA-AES128-GCM-SHA256 ECDHE-ECDSA-AES256-GCM-SHA384 - }` - c := setup.NewTestController(params) - - _, err := Setup(c) - if err != nil { - t.Errorf("Expected no errors, got: %v", err) - } - - if c.TLS.ProtocolMinVersion != tls.VersionSSL30 { - t.Errorf("Expected 'ssl3.0 (0x0300)' as ProtocolMinVersion, got %#v", c.TLS.ProtocolMinVersion) - } - - if c.TLS.ProtocolMaxVersion != tls.VersionTLS12 { - t.Errorf("Expected 'tls1.2 (0x0302)' as ProtocolMaxVersion, got %#v", c.TLS.ProtocolMaxVersion) - } - - if len(c.TLS.Ciphers)-1 != 3 { - t.Errorf("Expected 3 Ciphers (not including TLS_FALLBACK_SCSV), got %v", len(c.TLS.Ciphers)-1) - } -} - -func TestSetupDefaultWithOptionalParams(t *testing.T) { - params := `tls { - ciphers RSA-3DES-EDE-CBC-SHA - }` - c := setup.NewTestController(params) - - _, err := Setup(c) - if err != nil { - t.Errorf("Expected no errors, got: %v", err) - } - if len(c.TLS.Ciphers)-1 != 1 { - t.Errorf("Expected 1 ciphers (not including TLS_FALLBACK_SCSV), got %v", len(c.TLS.Ciphers)-1) - } -} - -// TODO: If we allow this... but probably not a good idea. -// func TestSetupDisableHTTPRedirect(t *testing.T) { -// c := NewTestController(`tls { -// allow_http -// }`) -// _, err := TLS(c) -// if err != nil { -// t.Errorf("Expected no error, but got %v", err) -// } -// if !c.TLS.DisableHTTPRedir { -// t.Error("Expected HTTP redirect to be disabled, but it wasn't") -// } -// } - -func TestSetupParseWithWrongOptionalParams(t *testing.T) { - // Test protocols wrong params - params := `tls ` + certFile + ` ` + keyFile + ` { - protocols ssl tls - }` - c := setup.NewTestController(params) - _, err := Setup(c) - if err == nil { - t.Errorf("Expected errors, but no error returned") - } - - // Test ciphers wrong params - params = `tls ` + certFile + ` ` + keyFile + ` { - ciphers not-valid-cipher - }` - c = setup.NewTestController(params) - _, err = Setup(c) - if err == nil { - t.Errorf("Expected errors, but no error returned") - } -} - -func TestSetupParseWithClientAuth(t *testing.T) { - params := `tls ` + certFile + ` ` + keyFile + ` { - clients client_ca.crt client2_ca.crt - }` - c := setup.NewTestController(params) - _, err := Setup(c) - if err != nil { - t.Errorf("Expected no errors, got: %v", err) - } - - if count := len(c.TLS.ClientCerts); count != 2 { - t.Fatalf("Expected two client certs, had %d", count) - } - if actual := c.TLS.ClientCerts[0]; actual != "client_ca.crt" { - t.Errorf("Expected first client cert file to be '%s', but was '%s'", "client_ca.crt", actual) - } - if actual := c.TLS.ClientCerts[1]; actual != "client2_ca.crt" { - t.Errorf("Expected second client cert file to be '%s', but was '%s'", "client2_ca.crt", actual) - } - - // Test missing client cert file - params = `tls ` + certFile + ` ` + keyFile + ` { - clients - }` - c = setup.NewTestController(params) - _, err = Setup(c) - if err == nil { - t.Errorf("Expected an error, but no error returned") - } -} - -const ( - certFile = "test_cert.pem" - keyFile = "test_key.pem" -) - -var testCert = []byte(`-----BEGIN CERTIFICATE----- -MIIBkjCCATmgAwIBAgIJANfFCBcABL6LMAkGByqGSM49BAEwFDESMBAGA1UEAxMJ -bG9jYWxob3N0MB4XDTE2MDIxMDIyMjAyNFoXDTE4MDIwOTIyMjAyNFowFDESMBAG -A1UEAxMJbG9jYWxob3N0MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEs22MtnG7 -9K1mvIyjEO9GLx7BFD0tBbGnwQ0VPsuCxC6IeVuXbQDLSiVQvFZ6lUszTlczNxVk -pEfqrM6xAupB7qN1MHMwHQYDVR0OBBYEFHxYDvAxUwL4XrjPev6qZ/BiLDs5MEQG -A1UdIwQ9MDuAFHxYDvAxUwL4XrjPev6qZ/BiLDs5oRikFjAUMRIwEAYDVQQDEwls -b2NhbGhvc3SCCQDXxQgXAAS+izAMBgNVHRMEBTADAQH/MAkGByqGSM49BAEDSAAw -RQIgRvBqbyJM2JCJqhA1FmcoZjeMocmhxQHTt1c+1N2wFUgCIQDtvrivbBPA688N -Qh3sMeAKNKPsx5NxYdoWuu9KWcKz9A== ------END CERTIFICATE----- -`) - -var testKey = []byte(`-----BEGIN EC PARAMETERS----- -BggqhkjOPQMBBw== ------END EC PARAMETERS----- ------BEGIN EC PRIVATE KEY----- -MHcCAQEEIGLtRmwzYVcrH3J0BnzYbGPdWVF10i9p6mxkA4+b2fURoAoGCCqGSM49 -AwEHoUQDQgAEs22MtnG79K1mvIyjEO9GLx7BFD0tBbGnwQ0VPsuCxC6IeVuXbQDL -SiVQvFZ6lUszTlczNxVkpEfqrM6xAupB7g== ------END EC PRIVATE KEY----- -`) -*/ diff --git a/core/https/storage.go b/core/https/storage.go deleted file mode 100644 index 5d8e949da..000000000 --- a/core/https/storage.go +++ /dev/null @@ -1,94 +0,0 @@ -package https - -import ( - "path/filepath" - "strings" - - "github.com/miekg/coredns/core/assets" -) - -// storage is used to get file paths in a consistent, -// cross-platform way for persisting Let's Encrypt assets -// on the file system. -var storage = Storage(filepath.Join(assets.Path(), "letsencrypt")) - -// Storage is a root directory and facilitates -// forming file paths derived from it. -type Storage string - -// Sites gets the directory that stores site certificate and keys. -func (s Storage) Sites() string { - return filepath.Join(string(s), "sites") -} - -// Site returns the path to the folder containing assets for domain. -func (s Storage) Site(domain string) string { - return filepath.Join(s.Sites(), domain) -} - -// SiteCertFile returns the path to the certificate file for domain. -func (s Storage) SiteCertFile(domain string) string { - return filepath.Join(s.Site(domain), domain+".crt") -} - -// SiteKeyFile returns the path to domain's private key file. -func (s Storage) SiteKeyFile(domain string) string { - return filepath.Join(s.Site(domain), domain+".key") -} - -// SiteMetaFile returns the path to the domain's asset metadata file. -func (s Storage) SiteMetaFile(domain string) string { - return filepath.Join(s.Site(domain), domain+".json") -} - -// Users gets the directory that stores account folders. -func (s Storage) Users() string { - return filepath.Join(string(s), "users") -} - -// User gets the account folder for the user with email. -func (s Storage) User(email string) string { - if email == "" { - email = emptyEmail - } - return filepath.Join(s.Users(), email) -} - -// UserRegFile gets the path to the registration file for -// the user with the given email address. -func (s Storage) UserRegFile(email string) string { - if email == "" { - email = emptyEmail - } - fileName := emailUsername(email) - if fileName == "" { - fileName = "registration" - } - return filepath.Join(s.User(email), fileName+".json") -} - -// UserKeyFile gets the path to the private key file for -// the user with the given email address. -func (s Storage) UserKeyFile(email string) string { - if email == "" { - email = emptyEmail - } - fileName := emailUsername(email) - if fileName == "" { - fileName = "private" - } - return filepath.Join(s.User(email), fileName+".key") -} - -// emailUsername returns the username portion of an -// email address (part before '@') or the original -// input if it can't find the "@" symbol. -func emailUsername(email string) string { - at := strings.Index(email, "@") - if at == -1 { - return email - } else if at == 0 { - return email[1:] - } - return email[:at] -} diff --git a/core/https/storage_test.go b/core/https/storage_test.go deleted file mode 100644 index 85c2220eb..000000000 --- a/core/https/storage_test.go +++ /dev/null @@ -1,88 +0,0 @@ -package https - -import ( - "path/filepath" - "testing" -) - -func TestStorage(t *testing.T) { - storage = Storage("./le_test") - - if expected, actual := filepath.Join("le_test", "sites"), storage.Sites(); actual != expected { - t.Errorf("Expected Sites() to return '%s' but got '%s'", expected, actual) - } - if expected, actual := filepath.Join("le_test", "sites", "test.com"), storage.Site("test.com"); actual != expected { - t.Errorf("Expected Site() to return '%s' but got '%s'", expected, actual) - } - if expected, actual := filepath.Join("le_test", "sites", "test.com", "test.com.crt"), storage.SiteCertFile("test.com"); actual != expected { - t.Errorf("Expected SiteCertFile() to return '%s' but got '%s'", expected, actual) - } - if expected, actual := filepath.Join("le_test", "sites", "test.com", "test.com.key"), storage.SiteKeyFile("test.com"); actual != expected { - t.Errorf("Expected SiteKeyFile() to return '%s' but got '%s'", expected, actual) - } - if expected, actual := filepath.Join("le_test", "sites", "test.com", "test.com.json"), storage.SiteMetaFile("test.com"); actual != expected { - t.Errorf("Expected SiteMetaFile() to return '%s' but got '%s'", expected, actual) - } - if expected, actual := filepath.Join("le_test", "users"), storage.Users(); actual != expected { - t.Errorf("Expected Users() to return '%s' but got '%s'", expected, actual) - } - if expected, actual := filepath.Join("le_test", "users", "me@example.com"), storage.User("me@example.com"); actual != expected { - t.Errorf("Expected User() to return '%s' but got '%s'", expected, actual) - } - if expected, actual := filepath.Join("le_test", "users", "me@example.com", "me.json"), storage.UserRegFile("me@example.com"); actual != expected { - t.Errorf("Expected UserRegFile() to return '%s' but got '%s'", expected, actual) - } - if expected, actual := filepath.Join("le_test", "users", "me@example.com", "me.key"), storage.UserKeyFile("me@example.com"); actual != expected { - t.Errorf("Expected UserKeyFile() to return '%s' but got '%s'", expected, actual) - } - - // Test with empty emails - if expected, actual := filepath.Join("le_test", "users", emptyEmail), storage.User(emptyEmail); actual != expected { - t.Errorf("Expected User(\"\") to return '%s' but got '%s'", expected, actual) - } - if expected, actual := filepath.Join("le_test", "users", emptyEmail, emptyEmail+".json"), storage.UserRegFile(""); actual != expected { - t.Errorf("Expected UserRegFile(\"\") to return '%s' but got '%s'", expected, actual) - } - if expected, actual := filepath.Join("le_test", "users", emptyEmail, emptyEmail+".key"), storage.UserKeyFile(""); actual != expected { - t.Errorf("Expected UserKeyFile(\"\") to return '%s' but got '%s'", expected, actual) - } -} - -func TestEmailUsername(t *testing.T) { - for i, test := range []struct { - input, expect string - }{ - { - input: "username@example.com", - expect: "username", - }, - { - input: "plus+addressing@example.com", - expect: "plus+addressing", - }, - { - input: "me+plus-addressing@example.com", - expect: "me+plus-addressing", - }, - { - input: "not-an-email", - expect: "not-an-email", - }, - { - input: "@foobar.com", - expect: "foobar.com", - }, - { - input: emptyEmail, - expect: emptyEmail, - }, - { - input: "", - expect: "", - }, - } { - if actual := emailUsername(test.input); actual != test.expect { - t.Errorf("Test %d: Expected username to be '%s' but was '%s'", i, test.expect, actual) - } - } -} diff --git a/core/https/user.go b/core/https/user.go deleted file mode 100644 index 9c30c656c..000000000 --- a/core/https/user.go +++ /dev/null @@ -1,200 +0,0 @@ -package https - -import ( - "bufio" - "crypto" - "crypto/ecdsa" - "crypto/elliptic" - "crypto/rand" - "encoding/json" - "errors" - "fmt" - "io" - "io/ioutil" - "os" - "strings" - - "github.com/miekg/coredns/server" - "github.com/xenolf/lego/acme" -) - -// User represents a Let's Encrypt user account. -type User struct { - Email string - Registration *acme.RegistrationResource - key crypto.PrivateKey -} - -// GetEmail gets u's email. -func (u User) GetEmail() string { - return u.Email -} - -// GetRegistration gets u's registration resource. -func (u User) GetRegistration() *acme.RegistrationResource { - return u.Registration -} - -// GetPrivateKey gets u's private key. -func (u User) GetPrivateKey() crypto.PrivateKey { - return u.key -} - -// getUser loads the user with the given email from disk. -// If the user does not exist, it will create a new one, -// but it does NOT save new users to the disk or register -// them via ACME. It does NOT prompt the user. -func getUser(email string) (User, error) { - var user User - - // open user file - regFile, err := os.Open(storage.UserRegFile(email)) - if err != nil { - if os.IsNotExist(err) { - // create a new user - return newUser(email) - } - return user, err - } - defer regFile.Close() - - // load user information - err = json.NewDecoder(regFile).Decode(&user) - if err != nil { - return user, err - } - - // load their private key - user.key, err = loadPrivateKey(storage.UserKeyFile(email)) - if err != nil { - return user, err - } - - return user, nil -} - -// saveUser persists a user's key and account registration -// to the file system. It does NOT register the user via ACME -// or prompt the user. -func saveUser(user User) error { - // make user account folder - err := os.MkdirAll(storage.User(user.Email), 0700) - if err != nil { - return err - } - - // save private key file - err = savePrivateKey(user.key, storage.UserKeyFile(user.Email)) - if err != nil { - return err - } - - // save registration file - jsonBytes, err := json.MarshalIndent(&user, "", "\t") - if err != nil { - return err - } - - return ioutil.WriteFile(storage.UserRegFile(user.Email), jsonBytes, 0600) -} - -// newUser creates a new User for the given email address -// with a new private key. This function does NOT save the -// user to disk or register it via ACME. If you want to use -// a user account that might already exist, call getUser -// instead. It does NOT prompt the user. -func newUser(email string) (User, error) { - user := User{Email: email} - privateKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader) - if err != nil { - return user, errors.New("error generating private key: " + err.Error()) - } - user.key = privateKey - return user, nil -} - -// getEmail does everything it can to obtain an email -// address from the user to use for TLS for cfg. If it -// cannot get an email address, it returns empty string. -// (It will warn the user of the consequences of an -// empty email.) This function MAY prompt the user for -// input. If userPresent is false, the operator will -// NOT be prompted and an empty email may be returned. -func getEmail(cfg server.Config, userPresent bool) string { - // First try the tls directive from the Corefile - leEmail := cfg.TLS.LetsEncryptEmail - if leEmail == "" { - // Then try memory (command line flag or typed by user previously) - leEmail = DefaultEmail - } - if leEmail == "" { - // Then try to get most recent user email ~/.coredns/users file - userDirs, err := ioutil.ReadDir(storage.Users()) - if err == nil { - var mostRecent os.FileInfo - for _, dir := range userDirs { - if !dir.IsDir() { - continue - } - if mostRecent == nil || dir.ModTime().After(mostRecent.ModTime()) { - leEmail = dir.Name() - DefaultEmail = leEmail // save for next time - } - } - } - } - if leEmail == "" && userPresent { - // Alas, we must bother the user and ask for an email address; - // if they proceed they also agree to the SA. - reader := bufio.NewReader(stdin) - fmt.Println("\nYour sites will be served over HTTPS automatically using Let's Encrypt.") - fmt.Println("By continuing, you agree to the Let's Encrypt Subscriber Agreement at:") - fmt.Println(" " + saURL) // TODO: Show current SA link - fmt.Println("Please enter your email address so you can recover your account if needed.") - fmt.Println("You can leave it blank, but you'll lose the ability to recover your account.") - fmt.Print("Email address: ") - var err error - leEmail, err = reader.ReadString('\n') - if err != nil { - return "" - } - leEmail = strings.TrimSpace(leEmail) - DefaultEmail = leEmail - Agreed = true - } - return leEmail -} - -// promptUserAgreement prompts the user to agree to the agreement -// at agreementURL via stdin. If the agreement has changed, then pass -// true as the second argument. If this is the user's first time -// agreeing, pass false. It returns whether the user agreed or not. -func promptUserAgreement(agreementURL string, changed bool) bool { - if changed { - fmt.Printf("The Let's Encrypt Subscriber Agreement has changed:\n %s\n", agreementURL) - fmt.Print("Do you agree to the new terms? (y/n): ") - } else { - fmt.Printf("To continue, you must agree to the Let's Encrypt Subscriber Agreement:\n %s\n", agreementURL) - fmt.Print("Do you agree to the terms? (y/n): ") - } - - reader := bufio.NewReader(stdin) - answer, err := reader.ReadString('\n') - if err != nil { - return false - } - answer = strings.ToLower(strings.TrimSpace(answer)) - - return answer == "y" || answer == "yes" -} - -// stdin is used to read the user's input if prompted; -// this is changed by tests during tests. -var stdin = io.ReadWriter(os.Stdin) - -// The name of the folder for accounts where the email -// address was not provided; default 'username' if you will. -const emptyEmail = "default" - -// TODO: Use latest -const saURL = "https://letsencrypt.org/documents/LE-SA-v1.0.1-July-27-2015.pdf" diff --git a/core/https/user_test.go b/core/https/user_test.go deleted file mode 100644 index 3e1af5007..000000000 --- a/core/https/user_test.go +++ /dev/null @@ -1,196 +0,0 @@ -package https - -import ( - "bytes" - "crypto/rand" - "crypto/rsa" - "io" - "os" - "strings" - "testing" - "time" - - "github.com/miekg/coredns/server" - "github.com/xenolf/lego/acme" -) - -func TestUser(t *testing.T) { - privateKey, err := rsa.GenerateKey(rand.Reader, 128) - if err != nil { - t.Fatalf("Could not generate test private key: %v", err) - } - u := User{ - Email: "me@mine.com", - Registration: new(acme.RegistrationResource), - key: privateKey, - } - - if expected, actual := "me@mine.com", u.GetEmail(); actual != expected { - t.Errorf("Expected email '%s' but got '%s'", expected, actual) - } - if u.GetRegistration() == nil { - t.Error("Expected a registration resource, but got nil") - } - if expected, actual := privateKey, u.GetPrivateKey(); actual != expected { - t.Errorf("Expected the private key at address %p but got one at %p instead ", expected, actual) - } -} - -func TestNewUser(t *testing.T) { - email := "me@foobar.com" - user, err := newUser(email) - if err != nil { - t.Fatalf("Error creating user: %v", err) - } - if user.key == nil { - t.Error("Private key is nil") - } - if user.Email != email { - t.Errorf("Expected email to be %s, but was %s", email, user.Email) - } - if user.Registration != nil { - t.Error("New user already has a registration resource; it shouldn't") - } -} - -func TestSaveUser(t *testing.T) { - storage = Storage("./testdata") - defer os.RemoveAll(string(storage)) - - email := "me@foobar.com" - user, err := newUser(email) - if err != nil { - t.Fatalf("Error creating user: %v", err) - } - - err = saveUser(user) - if err != nil { - t.Fatalf("Error saving user: %v", err) - } - _, err = os.Stat(storage.UserRegFile(email)) - if err != nil { - t.Errorf("Cannot access user registration file, error: %v", err) - } - _, err = os.Stat(storage.UserKeyFile(email)) - if err != nil { - t.Errorf("Cannot access user private key file, error: %v", err) - } -} - -func TestGetUserDoesNotAlreadyExist(t *testing.T) { - storage = Storage("./testdata") - defer os.RemoveAll(string(storage)) - - user, err := getUser("user_does_not_exist@foobar.com") - if err != nil { - t.Fatalf("Error getting user: %v", err) - } - - if user.key == nil { - t.Error("Expected user to have a private key, but it was nil") - } -} - -func TestGetUserAlreadyExists(t *testing.T) { - storage = Storage("./testdata") - defer os.RemoveAll(string(storage)) - - email := "me@foobar.com" - - // Set up test - user, err := newUser(email) - if err != nil { - t.Fatalf("Error creating user: %v", err) - } - err = saveUser(user) - if err != nil { - t.Fatalf("Error saving user: %v", err) - } - - // Expect to load user from disk - user2, err := getUser(email) - if err != nil { - t.Fatalf("Error getting user: %v", err) - } - - // Assert keys are the same - if !PrivateKeysSame(user.key, user2.key) { - t.Error("Expected private key to be the same after loading, but it wasn't") - } - - // Assert emails are the same - if user.Email != user2.Email { - t.Errorf("Expected emails to be equal, but was '%s' before and '%s' after loading", user.Email, user2.Email) - } -} - -func TestGetEmail(t *testing.T) { - // let's not clutter up the output - origStdout := os.Stdout - os.Stdout = nil - defer func() { os.Stdout = origStdout }() - - storage = Storage("./testdata") - defer os.RemoveAll(string(storage)) - DefaultEmail = "test2@foo.com" - - // Test1: Use email in config - config := server.Config{ - TLS: server.TLSConfig{ - LetsEncryptEmail: "test1@foo.com", - }, - } - actual := getEmail(config, true) - if actual != "test1@foo.com" { - t.Errorf("Did not get correct email from config; expected '%s' but got '%s'", "test1@foo.com", actual) - } - - // Test2: Use default email from flag (or user previously typing it) - actual = getEmail(server.Config{}, true) - if actual != DefaultEmail { - t.Errorf("Did not get correct email from config; expected '%s' but got '%s'", DefaultEmail, actual) - } - - // Test3: Get input from user - DefaultEmail = "" - stdin = new(bytes.Buffer) - _, err := io.Copy(stdin, strings.NewReader("test3@foo.com\n")) - if err != nil { - t.Fatalf("Could not simulate user input, error: %v", err) - } - actual = getEmail(server.Config{}, true) - if actual != "test3@foo.com" { - t.Errorf("Did not get correct email from user input prompt; expected '%s' but got '%s'", "test3@foo.com", actual) - } - - // Test4: Get most recent email from before - DefaultEmail = "" - for i, eml := range []string{ - "test4-3@foo.com", - "test4-2@foo.com", - "test4-1@foo.com", - } { - u, err := newUser(eml) - if err != nil { - t.Fatalf("Error creating user %d: %v", i, err) - } - err = saveUser(u) - if err != nil { - t.Fatalf("Error saving user %d: %v", i, err) - } - - // Change modified time so they're all different, so the test becomes deterministic - f, err := os.Stat(storage.User(eml)) - if err != nil { - t.Fatalf("Could not access user folder for '%s': %v", eml, err) - } - chTime := f.ModTime().Add(-(time.Duration(i) * time.Second)) - if err := os.Chtimes(storage.User(eml), chTime, chTime); err != nil { - t.Fatalf("Could not change user folder mod time for '%s': %v", eml, err) - } - } - actual = getEmail(server.Config{}, true) - if actual != "test4-3@foo.com" { - t.Errorf("Did not get correct email from storage; expected '%s' but got '%s'", "test4-3@foo.com", actual) - } -} diff --git a/core/parse/dispenser.go b/core/parse/dispenser.go deleted file mode 100644 index 08aa6e76d..000000000 --- a/core/parse/dispenser.go +++ /dev/null @@ -1,251 +0,0 @@ -package parse - -import ( - "errors" - "fmt" - "io" - "strings" -) - -// Dispenser is a type that dispenses tokens, similarly to a lexer, -// except that it can do so with some notion of structure and has -// some really convenient methods. -type Dispenser struct { - filename string - tokens []token - cursor int - nesting int -} - -// NewDispenser returns a Dispenser, ready to use for parsing the given input. -func NewDispenser(filename string, input io.Reader) Dispenser { - return Dispenser{ - filename: filename, - tokens: allTokens(input), - cursor: -1, - } -} - -// NewDispenserTokens returns a Dispenser filled with the given tokens. -func NewDispenserTokens(filename string, tokens []token) Dispenser { - return Dispenser{ - filename: filename, - tokens: tokens, - cursor: -1, - } -} - -// Next loads the next token. Returns true if a token -// was loaded; false otherwise. If false, all tokens -// have been consumed. -func (d *Dispenser) Next() bool { - if d.cursor < len(d.tokens)-1 { - d.cursor++ - return true - } - return false -} - -// NextArg loads the next token if it is on the same -// line. Returns true if a token was loaded; false -// otherwise. If false, all tokens on the line have -// been consumed. It handles imported tokens correctly. -func (d *Dispenser) NextArg() bool { - if d.cursor < 0 { - d.cursor++ - return true - } - if d.cursor >= len(d.tokens) { - return false - } - if d.cursor < len(d.tokens)-1 && - d.tokens[d.cursor].file == d.tokens[d.cursor+1].file && - d.tokens[d.cursor].line+d.numLineBreaks(d.cursor) == d.tokens[d.cursor+1].line { - d.cursor++ - return true - } - return false -} - -// NextLine loads the next token only if it is not on the same -// line as the current token, and returns true if a token was -// loaded; false otherwise. If false, there is not another token -// or it is on the same line. It handles imported tokens correctly. -func (d *Dispenser) NextLine() bool { - if d.cursor < 0 { - d.cursor++ - return true - } - if d.cursor >= len(d.tokens) { - return false - } - if d.cursor < len(d.tokens)-1 && - (d.tokens[d.cursor].file != d.tokens[d.cursor+1].file || - d.tokens[d.cursor].line+d.numLineBreaks(d.cursor) < d.tokens[d.cursor+1].line) { - d.cursor++ - return true - } - return false -} - -// NextBlock can be used as the condition of a for loop -// to load the next token as long as it opens a block or -// is already in a block. It returns true if a token was -// loaded, or false when the block's closing curly brace -// was loaded and thus the block ended. Nested blocks are -// not supported. -func (d *Dispenser) NextBlock() bool { - if d.nesting > 0 { - d.Next() - if d.Val() == "}" { - d.nesting-- - return false - } - return true - } - if !d.NextArg() { // block must open on same line - return false - } - if d.Val() != "{" { - d.cursor-- // roll back if not opening brace - return false - } - d.Next() - if d.Val() == "}" { - // Open and then closed right away - return false - } - d.nesting++ - return true -} - -// IncrNest adds a level of nesting to the dispenser. -func (d *Dispenser) IncrNest() { - d.nesting++ - return -} - -// Val gets the text of the current token. If there is no token -// loaded, it returns empty string. -func (d *Dispenser) Val() string { - if d.cursor < 0 || d.cursor >= len(d.tokens) { - return "" - } - return d.tokens[d.cursor].text -} - -// Line gets the line number of the current token. If there is no token -// loaded, it returns 0. -func (d *Dispenser) Line() int { - if d.cursor < 0 || d.cursor >= len(d.tokens) { - return 0 - } - return d.tokens[d.cursor].line -} - -// File gets the filename of the current token. If there is no token loaded, -// it returns the filename originally given when parsing started. -func (d *Dispenser) File() string { - if d.cursor < 0 || d.cursor >= len(d.tokens) { - return d.filename - } - if tokenFilename := d.tokens[d.cursor].file; tokenFilename != "" { - return tokenFilename - } - return d.filename -} - -// Args is a convenience function that loads the next arguments -// (tokens on the same line) into an arbitrary number of strings -// pointed to in targets. If there are fewer tokens available -// than string pointers, the remaining strings will not be changed -// and false will be returned. If there were enough tokens available -// to fill the arguments, then true will be returned. -func (d *Dispenser) Args(targets ...*string) bool { - enough := true - for i := 0; i < len(targets); i++ { - if !d.NextArg() { - enough = false - break - } - *targets[i] = d.Val() - } - return enough -} - -// RemainingArgs loads any more arguments (tokens on the same line) -// into a slice and returns them. Open curly brace tokens also indicate -// the end of arguments, and the curly brace is not included in -// the return value nor is it loaded. -func (d *Dispenser) RemainingArgs() []string { - var args []string - - for d.NextArg() { - if d.Val() == "{" { - d.cursor-- - break - } - args = append(args, d.Val()) - } - - return args -} - -// ArgErr returns an argument error, meaning that another -// argument was expected but not found. In other words, -// a line break or open curly brace was encountered instead of -// an argument. -func (d *Dispenser) ArgErr() error { - if d.Val() == "{" { - return d.Err("Unexpected token '{', expecting argument") - } - return d.Errf("Wrong argument count or unexpected line ending after '%s'", d.Val()) -} - -// SyntaxErr creates a generic syntax error which explains what was -// found and what was expected. -func (d *Dispenser) SyntaxErr(expected string) error { - msg := fmt.Sprintf("%s:%d - Syntax error: Unexpected token '%s', expecting '%s'", d.File(), d.Line(), d.Val(), expected) - return errors.New(msg) -} - -// EOFErr returns an error indicating that the dispenser reached -// the end of the input when searching for the next token. -func (d *Dispenser) EOFErr() error { - return d.Errf("Unexpected EOF") -} - -// Err generates a custom parse error with a message of msg. -func (d *Dispenser) Err(msg string) error { - msg = fmt.Sprintf("%s:%d - Parse error: %s", d.File(), d.Line(), msg) - return errors.New(msg) -} - -// Errf is like Err, but for formatted error messages -func (d *Dispenser) Errf(format string, args ...interface{}) error { - return d.Err(fmt.Sprintf(format, args...)) -} - -// numLineBreaks counts how many line breaks are in the token -// value given by the token index tknIdx. It returns 0 if the -// token does not exist or there are no line breaks. -func (d *Dispenser) numLineBreaks(tknIdx int) int { - if tknIdx < 0 || tknIdx >= len(d.tokens) { - return 0 - } - return strings.Count(d.tokens[tknIdx].text, "\n") -} - -// isNewLine determines whether the current token is on a different -// line (higher line number) than the previous token. It handles imported -// tokens correctly. If there isn't a previous token, it returns true. -func (d *Dispenser) isNewLine() bool { - if d.cursor < 1 { - return true - } - if d.cursor > len(d.tokens)-1 { - return false - } - return d.tokens[d.cursor-1].file != d.tokens[d.cursor].file || - d.tokens[d.cursor-1].line+d.numLineBreaks(d.cursor-1) < d.tokens[d.cursor].line -} diff --git a/core/parse/dispenser_test.go b/core/parse/dispenser_test.go deleted file mode 100644 index 20a7ddcac..000000000 --- a/core/parse/dispenser_test.go +++ /dev/null @@ -1,292 +0,0 @@ -package parse - -import ( - "reflect" - "strings" - "testing" -) - -func TestDispenser_Val_Next(t *testing.T) { - input := `host:port - dir1 arg1 - dir2 arg2 arg3 - dir3` - d := NewDispenser("Testfile", strings.NewReader(input)) - - if val := d.Val(); val != "" { - t.Fatalf("Val(): Should return empty string when no token loaded; got '%s'", val) - } - - assertNext := func(shouldLoad bool, expectedCursor int, expectedVal string) { - if loaded := d.Next(); loaded != shouldLoad { - t.Errorf("Next(): Expected %v but got %v instead (val '%s')", shouldLoad, loaded, d.Val()) - } - if d.cursor != expectedCursor { - t.Errorf("Expected cursor to be %d, but was %d", expectedCursor, d.cursor) - } - if d.nesting != 0 { - t.Errorf("Nesting should be 0, was %d instead", d.nesting) - } - if val := d.Val(); val != expectedVal { - t.Errorf("Val(): Expected '%s' but got '%s'", expectedVal, val) - } - } - - assertNext(true, 0, "host:port") - assertNext(true, 1, "dir1") - assertNext(true, 2, "arg1") - assertNext(true, 3, "dir2") - assertNext(true, 4, "arg2") - assertNext(true, 5, "arg3") - assertNext(true, 6, "dir3") - // Note: This next test simply asserts existing behavior. - // If desired, we may wish to empty the token value after - // reading past the EOF. Open an issue if you want this change. - assertNext(false, 6, "dir3") -} - -func TestDispenser_NextArg(t *testing.T) { - input := `dir1 arg1 - dir2 arg2 arg3 - dir3` - d := NewDispenser("Testfile", strings.NewReader(input)) - - assertNext := func(shouldLoad bool, expectedVal string, expectedCursor int) { - if d.Next() != shouldLoad { - t.Errorf("Next(): Should load token but got false instead (val: '%s')", d.Val()) - } - if d.cursor != expectedCursor { - t.Errorf("Next(): Expected cursor to be at %d, but it was %d", expectedCursor, d.cursor) - } - if val := d.Val(); val != expectedVal { - t.Errorf("Val(): Expected '%s' but got '%s'", expectedVal, val) - } - } - - assertNextArg := func(expectedVal string, loadAnother bool, expectedCursor int) { - if d.NextArg() != true { - t.Error("NextArg(): Should load next argument but got false instead") - } - if d.cursor != expectedCursor { - t.Errorf("NextArg(): Expected cursor to be at %d, but it was %d", expectedCursor, d.cursor) - } - if val := d.Val(); val != expectedVal { - t.Errorf("Val(): Expected '%s' but got '%s'", expectedVal, val) - } - if !loadAnother { - if d.NextArg() != false { - t.Fatalf("NextArg(): Should NOT load another argument, but got true instead (val: '%s')", d.Val()) - } - if d.cursor != expectedCursor { - t.Errorf("NextArg(): Expected cursor to remain at %d, but it was %d", expectedCursor, d.cursor) - } - } - } - - assertNext(true, "dir1", 0) - assertNextArg("arg1", false, 1) - assertNext(true, "dir2", 2) - assertNextArg("arg2", true, 3) - assertNextArg("arg3", false, 4) - assertNext(true, "dir3", 5) - assertNext(false, "dir3", 5) -} - -func TestDispenser_NextLine(t *testing.T) { - input := `host:port - dir1 arg1 - dir2 arg2 arg3` - d := NewDispenser("Testfile", strings.NewReader(input)) - - assertNextLine := func(shouldLoad bool, expectedVal string, expectedCursor int) { - if d.NextLine() != shouldLoad { - t.Errorf("NextLine(): Should load token but got false instead (val: '%s')", d.Val()) - } - if d.cursor != expectedCursor { - t.Errorf("NextLine(): Expected cursor to be %d, instead was %d", expectedCursor, d.cursor) - } - if val := d.Val(); val != expectedVal { - t.Errorf("Val(): Expected '%s' but got '%s'", expectedVal, val) - } - } - - assertNextLine(true, "host:port", 0) - assertNextLine(true, "dir1", 1) - assertNextLine(false, "dir1", 1) - d.Next() // arg1 - assertNextLine(true, "dir2", 3) - assertNextLine(false, "dir2", 3) - d.Next() // arg2 - assertNextLine(false, "arg2", 4) - d.Next() // arg3 - assertNextLine(false, "arg3", 5) -} - -func TestDispenser_NextBlock(t *testing.T) { - input := `foobar1 { - sub1 arg1 - sub2 - } - foobar2 { - }` - d := NewDispenser("Testfile", strings.NewReader(input)) - - assertNextBlock := func(shouldLoad bool, expectedCursor, expectedNesting int) { - if loaded := d.NextBlock(); loaded != shouldLoad { - t.Errorf("NextBlock(): Should return %v but got %v", shouldLoad, loaded) - } - if d.cursor != expectedCursor { - t.Errorf("NextBlock(): Expected cursor to be %d, was %d", expectedCursor, d.cursor) - } - if d.nesting != expectedNesting { - t.Errorf("NextBlock(): Nesting should be %d, not %d", expectedNesting, d.nesting) - } - } - - assertNextBlock(false, -1, 0) - d.Next() // foobar1 - assertNextBlock(true, 2, 1) - assertNextBlock(true, 3, 1) - assertNextBlock(true, 4, 1) - assertNextBlock(false, 5, 0) - d.Next() // foobar2 - assertNextBlock(false, 8, 0) // empty block is as if it didn't exist -} - -func TestDispenser_Args(t *testing.T) { - var s1, s2, s3 string - input := `dir1 arg1 arg2 arg3 - dir2 arg4 arg5 - dir3 arg6 arg7 - dir4` - d := NewDispenser("Testfile", strings.NewReader(input)) - - d.Next() // dir1 - - // As many strings as arguments - if all := d.Args(&s1, &s2, &s3); !all { - t.Error("Args(): Expected true, got false") - } - if s1 != "arg1" { - t.Errorf("Args(): Expected s1 to be 'arg1', got '%s'", s1) - } - if s2 != "arg2" { - t.Errorf("Args(): Expected s2 to be 'arg2', got '%s'", s2) - } - if s3 != "arg3" { - t.Errorf("Args(): Expected s3 to be 'arg3', got '%s'", s3) - } - - d.Next() // dir2 - - // More strings than arguments - if all := d.Args(&s1, &s2, &s3); all { - t.Error("Args(): Expected false, got true") - } - if s1 != "arg4" { - t.Errorf("Args(): Expected s1 to be 'arg4', got '%s'", s1) - } - if s2 != "arg5" { - t.Errorf("Args(): Expected s2 to be 'arg5', got '%s'", s2) - } - if s3 != "arg3" { - t.Errorf("Args(): Expected s3 to be unchanged ('arg3'), instead got '%s'", s3) - } - - // (quick cursor check just for kicks and giggles) - if d.cursor != 6 { - t.Errorf("Cursor should be 6, but is %d", d.cursor) - } - - d.Next() // dir3 - - // More arguments than strings - if all := d.Args(&s1); !all { - t.Error("Args(): Expected true, got false") - } - if s1 != "arg6" { - t.Errorf("Args(): Expected s1 to be 'arg6', got '%s'", s1) - } - - d.Next() // dir4 - - // No arguments or strings - if all := d.Args(); !all { - t.Error("Args(): Expected true, got false") - } - - // No arguments but at least one string - if all := d.Args(&s1); all { - t.Error("Args(): Expected false, got true") - } -} - -func TestDispenser_RemainingArgs(t *testing.T) { - input := `dir1 arg1 arg2 arg3 - dir2 arg4 arg5 - dir3 arg6 { arg7 - dir4` - d := NewDispenser("Testfile", strings.NewReader(input)) - - d.Next() // dir1 - - args := d.RemainingArgs() - if expected := []string{"arg1", "arg2", "arg3"}; !reflect.DeepEqual(args, expected) { - t.Errorf("RemainingArgs(): Expected %v, got %v", expected, args) - } - - d.Next() // dir2 - - args = d.RemainingArgs() - if expected := []string{"arg4", "arg5"}; !reflect.DeepEqual(args, expected) { - t.Errorf("RemainingArgs(): Expected %v, got %v", expected, args) - } - - d.Next() // dir3 - - args = d.RemainingArgs() - if expected := []string{"arg6"}; !reflect.DeepEqual(args, expected) { - t.Errorf("RemainingArgs(): Expected %v, got %v", expected, args) - } - - d.Next() // { - d.Next() // arg7 - d.Next() // dir4 - - args = d.RemainingArgs() - if len(args) != 0 { - t.Errorf("RemainingArgs(): Expected %v, got %v", []string{}, args) - } -} - -func TestDispenser_ArgErr_Err(t *testing.T) { - input := `dir1 { - } - dir2 arg1 arg2` - d := NewDispenser("Testfile", strings.NewReader(input)) - - d.cursor = 1 // { - - if err := d.ArgErr(); err == nil || !strings.Contains(err.Error(), "{") { - t.Errorf("ArgErr(): Expected an error message with { in it, but got '%v'", err) - } - - d.cursor = 5 // arg2 - - if err := d.ArgErr(); err == nil || !strings.Contains(err.Error(), "arg2") { - t.Errorf("ArgErr(): Expected an error message with 'arg2' in it; got '%v'", err) - } - - err := d.Err("foobar") - if err == nil { - t.Fatalf("Err(): Expected an error, got nil") - } - - if !strings.Contains(err.Error(), "Testfile:3") { - t.Errorf("Expected error message with filename:line in it; got '%v'", err) - } - - if !strings.Contains(err.Error(), "foobar") { - t.Errorf("Expected error message with custom message in it ('foobar'); got '%v'", err) - } -} diff --git a/core/parse/import_glob0.txt b/core/parse/import_glob0.txt deleted file mode 100644 index e610b5e7c..000000000 --- a/core/parse/import_glob0.txt +++ /dev/null @@ -1,6 +0,0 @@ -glob0.host0 { - dir2 arg1 -} - -glob0.host1 { -} diff --git a/core/parse/import_glob1.txt b/core/parse/import_glob1.txt deleted file mode 100644 index 111eb044d..000000000 --- a/core/parse/import_glob1.txt +++ /dev/null @@ -1,4 +0,0 @@ -glob1.host0 { - dir1 - dir2 arg1 -} diff --git a/core/parse/import_glob2.txt b/core/parse/import_glob2.txt deleted file mode 100644 index c09f784ec..000000000 --- a/core/parse/import_glob2.txt +++ /dev/null @@ -1,3 +0,0 @@ -glob2.host0 { - dir2 arg1 -} diff --git a/core/parse/import_test1.txt b/core/parse/import_test1.txt deleted file mode 100644 index dac7b29be..000000000 --- a/core/parse/import_test1.txt +++ /dev/null @@ -1,2 +0,0 @@ -dir2 arg1 arg2 -dir3 \ No newline at end of file diff --git a/core/parse/import_test2.txt b/core/parse/import_test2.txt deleted file mode 100644 index 140c87939..000000000 --- a/core/parse/import_test2.txt +++ /dev/null @@ -1,4 +0,0 @@ -host1 { - dir1 - dir2 arg1 -} \ No newline at end of file diff --git a/core/parse/lexer.go b/core/parse/lexer.go deleted file mode 100644 index d2939eba2..000000000 --- a/core/parse/lexer.go +++ /dev/null @@ -1,122 +0,0 @@ -package parse - -import ( - "bufio" - "io" - "unicode" -) - -type ( - // lexer is a utility which can get values, token by - // token, from a Reader. A token is a word, and tokens - // are separated by whitespace. A word can be enclosed - // in quotes if it contains whitespace. - lexer struct { - reader *bufio.Reader - token token - line int - } - - // token represents a single parsable unit. - token struct { - file string - line int - text string - } -) - -// load prepares the lexer to scan an input for tokens. -func (l *lexer) load(input io.Reader) error { - l.reader = bufio.NewReader(input) - l.line = 1 - return nil -} - -// next loads the next token into the lexer. -// A token is delimited by whitespace, unless -// the token starts with a quotes character (") -// in which case the token goes until the closing -// quotes (the enclosing quotes are not included). -// Inside quoted strings, quotes may be escaped -// with a preceding \ character. No other chars -// may be escaped. The rest of the line is skipped -// if a "#" character is read in. Returns true if -// a token was loaded; false otherwise. -func (l *lexer) next() bool { - var val []rune - var comment, quoted, escaped bool - - makeToken := func() bool { - l.token.text = string(val) - return true - } - - for { - ch, _, err := l.reader.ReadRune() - if err != nil { - if len(val) > 0 { - return makeToken() - } - if err == io.EOF { - return false - } - panic(err) - } - - if quoted { - if !escaped { - if ch == '\\' { - escaped = true - continue - } else if ch == '"' { - quoted = false - return makeToken() - } - } - if ch == '\n' { - l.line++ - } - if escaped { - // only escape quotes - if ch != '"' { - val = append(val, '\\') - } - } - val = append(val, ch) - escaped = false - continue - } - - if unicode.IsSpace(ch) { - if ch == '\r' { - continue - } - if ch == '\n' { - l.line++ - comment = false - } - if len(val) > 0 { - return makeToken() - } - continue - } - - if ch == '#' { - comment = true - } - - if comment { - continue - } - - if len(val) == 0 { - l.token = token{line: l.line} - if ch == '"' { - quoted = true - continue - } - } - - val = append(val, ch) - } -} diff --git a/core/parse/lexer_test.go b/core/parse/lexer_test.go deleted file mode 100644 index f12c7e7dc..000000000 --- a/core/parse/lexer_test.go +++ /dev/null @@ -1,165 +0,0 @@ -package parse - -import ( - "strings" - "testing" -) - -type lexerTestCase struct { - input string - expected []token -} - -func TestLexer(t *testing.T) { - testCases := []lexerTestCase{ - { - input: `host:123`, - expected: []token{ - {line: 1, text: "host:123"}, - }, - }, - { - input: `host:123 - - directive`, - expected: []token{ - {line: 1, text: "host:123"}, - {line: 3, text: "directive"}, - }, - }, - { - input: `host:123 { - directive - }`, - expected: []token{ - {line: 1, text: "host:123"}, - {line: 1, text: "{"}, - {line: 2, text: "directive"}, - {line: 3, text: "}"}, - }, - }, - { - input: `host:123 { directive }`, - expected: []token{ - {line: 1, text: "host:123"}, - {line: 1, text: "{"}, - {line: 1, text: "directive"}, - {line: 1, text: "}"}, - }, - }, - { - input: `host:123 { - #comment - directive - # comment - foobar # another comment - }`, - expected: []token{ - {line: 1, text: "host:123"}, - {line: 1, text: "{"}, - {line: 3, text: "directive"}, - {line: 5, text: "foobar"}, - {line: 6, text: "}"}, - }, - }, - { - input: `a "quoted value" b - foobar`, - expected: []token{ - {line: 1, text: "a"}, - {line: 1, text: "quoted value"}, - {line: 1, text: "b"}, - {line: 2, text: "foobar"}, - }, - }, - { - input: `A "quoted \"value\" inside" B`, - expected: []token{ - {line: 1, text: "A"}, - {line: 1, text: `quoted "value" inside`}, - {line: 1, text: "B"}, - }, - }, - { - input: `"don't\escape"`, - expected: []token{ - {line: 1, text: `don't\escape`}, - }, - }, - { - input: `"don't\\escape"`, - expected: []token{ - {line: 1, text: `don't\\escape`}, - }, - }, - { - input: `A "quoted value with line - break inside" { - foobar - }`, - expected: []token{ - {line: 1, text: "A"}, - {line: 1, text: "quoted value with line\n\t\t\t\t\tbreak inside"}, - {line: 2, text: "{"}, - {line: 3, text: "foobar"}, - {line: 4, text: "}"}, - }, - }, - { - input: `"C:\php\php-cgi.exe"`, - expected: []token{ - {line: 1, text: `C:\php\php-cgi.exe`}, - }, - }, - { - input: `empty "" string`, - expected: []token{ - {line: 1, text: `empty`}, - {line: 1, text: ``}, - {line: 1, text: `string`}, - }, - }, - { - input: "skip those\r\nCR characters", - expected: []token{ - {line: 1, text: "skip"}, - {line: 1, text: "those"}, - {line: 2, text: "CR"}, - {line: 2, text: "characters"}, - }, - }, - } - - for i, testCase := range testCases { - actual := tokenize(testCase.input) - lexerCompare(t, i, testCase.expected, actual) - } -} - -func tokenize(input string) (tokens []token) { - l := lexer{} - l.load(strings.NewReader(input)) - for l.next() { - tokens = append(tokens, l.token) - } - return -} - -func lexerCompare(t *testing.T, n int, expected, actual []token) { - if len(expected) != len(actual) { - t.Errorf("Test case %d: expected %d token(s) but got %d", n, len(expected), len(actual)) - } - - for i := 0; i < len(actual) && i < len(expected); i++ { - if actual[i].line != expected[i].line { - t.Errorf("Test case %d token %d ('%s'): expected line %d but was line %d", - n, i, expected[i].text, expected[i].line, actual[i].line) - break - } - if actual[i].text != expected[i].text { - t.Errorf("Test case %d token %d: expected text '%s' but was '%s'", - n, i, expected[i].text, actual[i].text) - break - } - } -} diff --git a/core/parse/parse.go b/core/parse/parse.go deleted file mode 100644 index faef36c28..000000000 --- a/core/parse/parse.go +++ /dev/null @@ -1,32 +0,0 @@ -// Package parse provides facilities for parsing configuration files. -package parse - -import "io" - -// ServerBlocks parses the input just enough to organize tokens, -// in order, by server block. No further parsing is performed. -// If checkDirectives is true, only valid directives will be allowed -// otherwise we consider it a parse error. Server blocks are returned -// in the order in which they appear. -func ServerBlocks(filename string, input io.Reader, checkDirectives bool) ([]ServerBlock, error) { - p := parser{Dispenser: NewDispenser(filename, input)} - p.checkDirectives = checkDirectives - blocks, err := p.parseAll() - return blocks, err -} - -// allTokens lexes the entire input, but does not parse it. -// It returns all the tokens from the input, unstructured -// and in order. -func allTokens(input io.Reader) (tokens []token) { - l := new(lexer) - l.load(input) - for l.next() { - tokens = append(tokens, l.token) - } - return -} - -// ValidDirectives is a set of directives that are valid (unordered). Populated -// by config package's init function. -var ValidDirectives = make(map[string]struct{}) diff --git a/core/parse/parse_test.go b/core/parse/parse_test.go deleted file mode 100644 index 48746300f..000000000 --- a/core/parse/parse_test.go +++ /dev/null @@ -1,22 +0,0 @@ -package parse - -import ( - "strings" - "testing" -) - -func TestAllTokens(t *testing.T) { - input := strings.NewReader("a b c\nd e") - expected := []string{"a", "b", "c", "d", "e"} - tokens := allTokens(input) - - if len(tokens) != len(expected) { - t.Fatalf("Expected %d tokens, got %d", len(expected), len(tokens)) - } - - for i, val := range expected { - if tokens[i].text != val { - t.Errorf("Token %d should be '%s' but was '%s'", i, val, tokens[i].text) - } - } -} diff --git a/core/parse/parsing.go b/core/parse/parsing.go deleted file mode 100644 index 7be7d6714..000000000 --- a/core/parse/parsing.go +++ /dev/null @@ -1,388 +0,0 @@ -package parse - -import ( - "fmt" - "net" - "os" - "path/filepath" - "strings" - - "github.com/miekg/dns" -) - -type parser struct { - Dispenser - block ServerBlock // current server block being parsed - eof bool // if we encounter a valid EOF in a hard place - checkDirectives bool // if true, directives must be known -} - -func (p *parser) parseAll() ([]ServerBlock, error) { - var blocks []ServerBlock - - for p.Next() { - err := p.parseOne() - if err != nil { - return blocks, err - } - if len(p.block.Addresses) > 0 { - blocks = append(blocks, p.block) - } - } - - return blocks, nil -} - -func (p *parser) parseOne() error { - p.block = ServerBlock{Tokens: make(map[string][]token)} - - err := p.begin() - if err != nil { - return err - } - - return nil -} - -func (p *parser) begin() error { - if len(p.tokens) == 0 { - return nil - } - - err := p.addresses() - if err != nil { - return err - } - - if p.eof { - // this happens if the Corefile consists of only - // a line of addresses and nothing else - return nil - } - - err = p.blockContents() - if err != nil { - return err - } - - return nil -} - -func (p *parser) addresses() error { - var expectingAnother bool - - for { - tkn := replaceEnvVars(p.Val()) - - // special case: import directive replaces tokens during parse-time - if tkn == "import" && p.isNewLine() { - err := p.doImport() - if err != nil { - return err - } - continue - } - - // Open brace definitely indicates end of addresses - if tkn == "{" { - if expectingAnother { - return p.Errf("Expected another address but had '%s' - check for extra comma", tkn) - } - break - } - - if tkn != "" { // empty token possible if user typed "" in Corefile - // Trailing comma indicates another address will follow, which - // may possibly be on the next line - if tkn[len(tkn)-1] == ',' { - tkn = tkn[:len(tkn)-1] - expectingAnother = true - } else { - expectingAnother = false // but we may still see another one on this line - } - - // Parse and save this address - addr, err := standardAddress(tkn) - if err != nil { - return err - } - p.block.Addresses = append(p.block.Addresses, addr) - } - - // Advance token and possibly break out of loop or return error - hasNext := p.Next() - if expectingAnother && !hasNext { - return p.EOFErr() - } - if !hasNext { - p.eof = true - break // EOF - } - if !expectingAnother && p.isNewLine() { - break - } - } - - return nil -} - -func (p *parser) blockContents() error { - errOpenCurlyBrace := p.openCurlyBrace() - if errOpenCurlyBrace != nil { - // single-server configs don't need curly braces - p.cursor-- - } - - err := p.directives() - if err != nil { - return err - } - - // Only look for close curly brace if there was an opening - if errOpenCurlyBrace == nil { - err = p.closeCurlyBrace() - if err != nil { - return err - } - } - - return nil -} - -// directives parses through all the lines for directives -// and it expects the next token to be the first -// directive. It goes until EOF or closing curly brace -// which ends the server block. -func (p *parser) directives() error { - for p.Next() { - // end of server block - if p.Val() == "}" { - break - } - - // special case: import directive replaces tokens during parse-time - if p.Val() == "import" { - err := p.doImport() - if err != nil { - return err - } - p.cursor-- // cursor is advanced when we continue, so roll back one more - continue - } - - // normal case: parse a directive on this line - if err := p.directive(); err != nil { - return err - } - } - return nil -} - -// doImport swaps out the import directive and its argument -// (a total of 2 tokens) with the tokens in the specified file -// or globbing pattern. When the function returns, the cursor -// is on the token before where the import directive was. In -// other words, call Next() to access the first token that was -// imported. -func (p *parser) doImport() error { - // syntax check - if !p.NextArg() { - return p.ArgErr() - } - importPattern := p.Val() - if p.NextArg() { - return p.Err("Import takes only one argument (glob pattern or file)") - } - - // do glob - matches, err := filepath.Glob(importPattern) - if err != nil { - return p.Errf("Failed to use import pattern %s: %v", importPattern, err) - } - if len(matches) == 0 { - return p.Errf("No files matching import pattern %s", importPattern) - } - - // splice out the import directive and its argument (2 tokens total) - tokensBefore := p.tokens[:p.cursor-1] - tokensAfter := p.tokens[p.cursor+1:] - - // collect all the imported tokens - var importedTokens []token - for _, importFile := range matches { - newTokens, err := p.doSingleImport(importFile) - if err != nil { - return err - } - importedTokens = append(importedTokens, newTokens...) - } - - // splice the imported tokens in the place of the import statement - // and rewind cursor so Next() will land on first imported token - p.tokens = append(tokensBefore, append(importedTokens, tokensAfter...)...) - p.cursor-- - - return nil -} - -// doSingleImport lexes the individual file at importFile and returns -// its tokens or an error, if any. -func (p *parser) doSingleImport(importFile string) ([]token, error) { - file, err := os.Open(importFile) - if err != nil { - return nil, p.Errf("Could not import %s: %v", importFile, err) - } - defer file.Close() - importedTokens := allTokens(file) - - // Tack the filename onto these tokens so errors show the imported file's name - filename := filepath.Base(importFile) - for i := 0; i < len(importedTokens); i++ { - importedTokens[i].file = filename - } - - return importedTokens, nil -} - -// directive collects tokens until the directive's scope -// closes (either end of line or end of curly brace block). -// It expects the currently-loaded token to be a directive -// (or } that ends a server block). The collected tokens -// are loaded into the current server block for later use -// by directive setup functions. -func (p *parser) directive() error { - dir := p.Val() - nesting := 0 - - if p.checkDirectives { - if _, ok := ValidDirectives[dir]; !ok { - return p.Errf("Unknown directive '%s'", dir) - } - } - - // The directive itself is appended as a relevant token - p.block.Tokens[dir] = append(p.block.Tokens[dir], p.tokens[p.cursor]) - - for p.Next() { - if p.Val() == "{" { - nesting++ - } else if p.isNewLine() && nesting == 0 { - p.cursor-- // read too far - break - } else if p.Val() == "}" && nesting > 0 { - nesting-- - } else if p.Val() == "}" && nesting == 0 { - return p.Err("Unexpected '}' because no matching opening brace") - } - p.tokens[p.cursor].text = replaceEnvVars(p.tokens[p.cursor].text) - p.block.Tokens[dir] = append(p.block.Tokens[dir], p.tokens[p.cursor]) - } - - if nesting > 0 { - return p.EOFErr() - } - return nil -} - -// openCurlyBrace expects the current token to be an -// opening curly brace. This acts like an assertion -// because it returns an error if the token is not -// a opening curly brace. It does NOT advance the token. -func (p *parser) openCurlyBrace() error { - if p.Val() != "{" { - return p.SyntaxErr("{") - } - return nil -} - -// closeCurlyBrace expects the current token to be -// a closing curly brace. This acts like an assertion -// because it returns an error if the token is not -// a closing curly brace. It does NOT advance the token. -func (p *parser) closeCurlyBrace() error { - if p.Val() != "}" { - return p.SyntaxErr("}") - } - return nil -} - -// standardAddress parses an address string into a structured format with separate -// host, and port portions, as well as the original input string. -func standardAddress(str string) (address, error) { - var err error - - // first check for scheme and strip it off - input := str - - // separate host and port - host, port, err := net.SplitHostPort(str) - if err != nil { - host, port, err = net.SplitHostPort(str + ":") - // no error check here; return err at end of function - } - - if len(host) > 255 { - return address{}, fmt.Errorf("specified address is too long: %d > 255", len(host)) - } - _, d := dns.IsDomainName(host) - if !d { - return address{}, fmt.Errorf("host is not a valid domain: %s", host) - } - - // see if we can set port based off scheme - if port == "" { - port = "53" - } - - return address{Original: input, Host: strings.ToLower(dns.Fqdn(host)), Port: port}, err -} - -// replaceEnvVars replaces environment variables that appear in the token -// and understands both the $UNIX and %WINDOWS% syntaxes. -func replaceEnvVars(s string) string { - s = replaceEnvReferences(s, "{%", "%}") - s = replaceEnvReferences(s, "{$", "}") - return s -} - -// replaceEnvReferences performs the actual replacement of env variables -// in s, given the placeholder start and placeholder end strings. -func replaceEnvReferences(s, refStart, refEnd string) string { - index := strings.Index(s, refStart) - for index != -1 { - endIndex := strings.Index(s, refEnd) - if endIndex != -1 { - ref := s[index : endIndex+len(refEnd)] - s = strings.Replace(s, ref, os.Getenv(ref[len(refStart):len(ref)-len(refEnd)]), -1) - } else { - return s - } - index = strings.Index(s, refStart) - } - return s -} - -type ( - // ServerBlock associates tokens with a list of addresses - // and groups tokens by directive name. - ServerBlock struct { - Addresses []address - Tokens map[string][]token - } - - address struct { - Original, Host, Port string - } -) - -// HostList converts the list of addresses that are -// associated with this server block into a slice of -// strings, where each address is as it was originally -// read from the input. -func (sb ServerBlock) HostList() []string { - sbHosts := make([]string, len(sb.Addresses)) - for j, addr := range sb.Addresses { - sbHosts[j] = addr.Original - } - return sbHosts -} diff --git a/core/parse/parsing_test.go b/core/parse/parsing_test.go deleted file mode 100644 index d06fc8b58..000000000 --- a/core/parse/parsing_test.go +++ /dev/null @@ -1,401 +0,0 @@ -package parse - -import ( - "os" - "strings" - "testing" -) - -func TestStandardAddress(t *testing.T) { - for i, test := range []struct { - input string - host, port string - shouldErr bool - }{ - {`localhost`, "localhost.", "53", false}, - {`localhost:1234`, "localhost.", "1234", false}, - {`localhost:`, "localhost.", "53", false}, - {`0.0.0.0`, "0.0.0.0.", "53", false}, - {`127.0.0.1:1234`, "127.0.0.1.", "1234", false}, - {`:1234`, ".", "1234", false}, - {`[::1]`, "::1.", "53", false}, - {`[::1]:1234`, "::1.", "1234", false}, - {`:`, ".", "53", false}, - {`localhost:http`, "localhost.", "http", false}, - {`localhost:https`, "localhost.", "https", false}, - {``, ".", "53", false}, - {`::1`, "::1.", "53", true}, - {`localhost::`, "localhost::.", "53", true}, - {`#$%@`, "#$%@.", "53", true}, - } { - actual, err := standardAddress(test.input) - - if err != nil && !test.shouldErr { - t.Errorf("Test %d (%s): Expected no error, but had error: %v", i, test.input, err) - } - if err == nil && test.shouldErr { - t.Errorf("Test %d (%s): Expected error, but had none", i, test.input) - } - - if actual.Host != test.host { - t.Errorf("Test %d (%s): Expected host '%s', got '%s'", i, test.input, test.host, actual.Host) - } - if actual.Port != test.port { - t.Errorf("Test %d (%s): Expected port '%s', got '%s'", i, test.input, test.port, actual.Port) - } - } -} - -func TestParseOneAndImport(t *testing.T) { - setupParseTests() - - testParseOne := func(input string) (ServerBlock, error) { - p := testParser(input) - p.Next() // parseOne doesn't call Next() to start, so we must - err := p.parseOne() - return p.block, err - } - - for i, test := range []struct { - input string - shouldErr bool - addresses []address - tokens map[string]int // map of directive name to number of tokens expected - }{ - {`localhost`, false, []address{ - {"localhost", "localhost.", "53"}, - }, map[string]int{}}, - - {`localhost - dir1`, false, []address{ - {"localhost", "localhost.", "53"}, - }, map[string]int{ - "dir1": 1, - }}, - - {`localhost:1234 - dir1 foo bar`, false, []address{ - {"localhost:1234", "localhost.", "1234"}, - }, map[string]int{ - "dir1": 3, - }}, - - {`localhost { - dir1 - }`, false, []address{ - {"localhost", "localhost.", "53"}, - }, map[string]int{ - "dir1": 1, - }}, - - {`localhost:1234 { - dir1 foo bar - dir2 - }`, false, []address{ - {"localhost:1234", "localhost.", "1234"}, - }, map[string]int{ - "dir1": 3, - "dir2": 1, - }}, - - {`host1:80, host2.com - dir1 foo bar - dir2 baz`, false, []address{ - {"host1:80", "host1.", "80"}, - {"host2.com", "host2.com.", "53"}, - }, map[string]int{ - "dir1": 3, - "dir2": 2, - }}, - - {`127.0.0.1 - dir1 { - bar baz - } - dir2 { - foo bar - }`, false, []address{ - {"127.0.0.1", "127.0.0.1.", "53"}, - }, map[string]int{ - "dir1": 5, - "dir2": 5, - }}, - - {`127.0.0.1 - unknown_directive`, true, []address{ - {"127.0.0.1", "127.0.0.1.", "53"}, - }, map[string]int{}}, - - {`localhost - dir1 { - foo`, true, []address{ - {"localhost", "localhost.", "53"}, - }, map[string]int{ - "dir1": 3, - }}, - - {`localhost - dir1 { - }`, false, []address{ - {"localhost", "localhost.", "53"}, - }, map[string]int{ - "dir1": 3, - }}, - - {`localhost - dir1 { - } }`, true, []address{ - {"localhost", "localhost.", "53"}, - }, map[string]int{ - "dir1": 3, - }}, - - {`localhost - dir1 { - nested { - foo - } - } - dir2 foo bar`, false, []address{ - {"localhost", "localhost.", "53"}, - }, map[string]int{ - "dir1": 7, - "dir2": 3, - }}, - - {``, false, []address{}, map[string]int{}}, - - {`localhost - dir1 arg1 - import import_test1.txt`, false, []address{ - {"localhost", "localhost.", "53"}, - }, map[string]int{ - "dir1": 2, - "dir2": 3, - "dir3": 1, - }}, - - {`import import_test2.txt`, false, []address{ - {"host1", "host1.", "53"}, - }, map[string]int{ - "dir1": 1, - "dir2": 2, - }}, - - {`import import_test1.txt import_test2.txt`, true, []address{}, map[string]int{}}, - - {`import not_found.txt`, true, []address{}, map[string]int{}}, - - {`""`, false, []address{}, map[string]int{}}, - - {``, false, []address{}, map[string]int{}}, - } { - result, err := testParseOne(test.input) - - if test.shouldErr && err == nil { - t.Errorf("Test %d: Expected an error, but didn't get one", i) - } - if !test.shouldErr && err != nil { - t.Errorf("Test %d: Expected no error, but got: %v", i, err) - } - - if len(result.Addresses) != len(test.addresses) { - t.Errorf("Test %d: Expected %d addresses, got %d", - i, len(test.addresses), len(result.Addresses)) - continue - } - for j, addr := range result.Addresses { - if addr.Host != test.addresses[j].Host { - t.Errorf("Test %d, address %d: Expected host to be '%s', but was '%s'", - i, j, test.addresses[j].Host, addr.Host) - } - if addr.Port != test.addresses[j].Port { - t.Errorf("Test %d, address %d: Expected port to be '%s', but was '%s'", - i, j, test.addresses[j].Port, addr.Port) - } - } - - if len(result.Tokens) != len(test.tokens) { - t.Errorf("Test %d: Expected %d directives, had %d", - i, len(test.tokens), len(result.Tokens)) - continue - } - for directive, tokens := range result.Tokens { - if len(tokens) != test.tokens[directive] { - t.Errorf("Test %d, directive '%s': Expected %d tokens, counted %d", - i, directive, test.tokens[directive], len(tokens)) - continue - } - } - } -} - -func TestParseAll(t *testing.T) { - setupParseTests() - - for i, test := range []struct { - input string - shouldErr bool - addresses [][]address // addresses per server block, in order - }{ - {`localhost`, false, [][]address{ - {{"localhost", "localhost.", "53"}}, - }}, - - {`localhost:1234`, false, [][]address{ - {{"localhost:1234", "localhost.", "1234"}}, - }}, - - {`localhost:1234 { - } - localhost:2015 { - }`, false, [][]address{ - {{"localhost:1234", "localhost.", "1234"}}, - {{"localhost:2015", "localhost.", "2015"}}, - }}, - - {`localhost:1234, host2`, false, [][]address{ - {{"localhost:1234", "localhost.", "1234"}, {"host2", "host2.", "53"}}, - }}, - - {`localhost:1234, http://host2,`, true, [][]address{}}, - - {`import import_glob*.txt`, false, [][]address{ - {{"glob0.host0", "glob0.host0.", "53"}}, - {{"glob0.host1", "glob0.host1.", "53"}}, - {{"glob1.host0", "glob1.host0.", "53"}}, - {{"glob2.host0", "glob2.host0.", "53"}}, - }}, - } { - p := testParser(test.input) - blocks, err := p.parseAll() - - if test.shouldErr && err == nil { - t.Errorf("Test %d: Expected an error, but didn't get one", i) - } - if !test.shouldErr && err != nil { - t.Errorf("Test %d: Expected no error, but got: %v", i, err) - } - - if len(blocks) != len(test.addresses) { - t.Errorf("Test %d: Expected %d server blocks, got %d", - i, len(test.addresses), len(blocks)) - continue - } - for j, block := range blocks { - if len(block.Addresses) != len(test.addresses[j]) { - t.Errorf("Test %d: Expected %d addresses in block %d, got %d", - i, len(test.addresses[j]), j, len(block.Addresses)) - continue - } - for k, addr := range block.Addresses { - if addr.Host != test.addresses[j][k].Host { - t.Errorf("Test %d, block %d, address %d: Expected host to be '%s', but was '%s'", - i, j, k, test.addresses[j][k].Host, addr.Host) - } - if addr.Port != test.addresses[j][k].Port { - t.Errorf("Test %d, block %d, address %d: Expected port to be '%s', but was '%s'", - i, j, k, test.addresses[j][k].Port, addr.Port) - } - } - } - } -} - -func TestEnvironmentReplacement(t *testing.T) { - setupParseTests() - - os.Setenv("PORT", "8080") - os.Setenv("ADDRESS", "servername.com") - os.Setenv("FOOBAR", "foobar") - - // basic test; unix-style env vars - p := testParser(`{$ADDRESS}`) - blocks, _ := p.parseAll() - if actual, expected := blocks[0].Addresses[0].Host, "servername.com."; expected != actual { - t.Errorf("Expected host to be '%s' but was '%s'", expected, actual) - } - - // multiple vars per token - p = testParser(`{$ADDRESS}:{$PORT}`) - blocks, _ = p.parseAll() - if actual, expected := blocks[0].Addresses[0].Host, "servername.com."; expected != actual { - t.Errorf("Expected host to be '%s' but was '%s'", expected, actual) - } - if actual, expected := blocks[0].Addresses[0].Port, "8080"; expected != actual { - t.Errorf("Expected port to be '%s' but was '%s'", expected, actual) - } - - // windows-style var and unix style in same token - p = testParser(`{%ADDRESS%}:{$PORT}`) - blocks, _ = p.parseAll() - if actual, expected := blocks[0].Addresses[0].Host, "servername.com."; expected != actual { - t.Errorf("Expected host to be '%s' but was '%s'", expected, actual) - } - if actual, expected := blocks[0].Addresses[0].Port, "8080"; expected != actual { - t.Errorf("Expected port to be '%s' but was '%s'", expected, actual) - } - - // reverse order - p = testParser(`{$ADDRESS}:{%PORT%}`) - blocks, _ = p.parseAll() - if actual, expected := blocks[0].Addresses[0].Host, "servername.com."; expected != actual { - t.Errorf("Expected host to be '%s' but was '%s'", expected, actual) - } - if actual, expected := blocks[0].Addresses[0].Port, "8080"; expected != actual { - t.Errorf("Expected port to be '%s' but was '%s'", expected, actual) - } - - // env var in server block body as argument - p = testParser(":{%PORT%}\ndir1 {$FOOBAR}") - blocks, _ = p.parseAll() - if actual, expected := blocks[0].Addresses[0].Port, "8080"; expected != actual { - t.Errorf("Expected port to be '%s' but was '%s'", expected, actual) - } - if actual, expected := blocks[0].Tokens["dir1"][1].text, "foobar"; expected != actual { - t.Errorf("Expected argument to be '%s' but was '%s'", expected, actual) - } - - // combined windows env vars in argument - p = testParser(":{%PORT%}\ndir1 {%ADDRESS%}/{%FOOBAR%}") - blocks, _ = p.parseAll() - if actual, expected := blocks[0].Tokens["dir1"][1].text, "servername.com/foobar"; expected != actual { - t.Errorf("Expected argument to be '%s' but was '%s'", expected, actual) - } - - // malformed env var (windows) - p = testParser(":1234\ndir1 {%ADDRESS}") - blocks, _ = p.parseAll() - if actual, expected := blocks[0].Tokens["dir1"][1].text, "{%ADDRESS}"; expected != actual { - t.Errorf("Expected host to be '%s' but was '%s'", expected, actual) - } - - // malformed (non-existent) env var (unix) - p = testParser(`:{$PORT$}`) - blocks, _ = p.parseAll() - if actual, expected := blocks[0].Addresses[0].Port, "53"; expected != actual { - t.Errorf("Expected port to be '%s' but was '%s'", expected, actual) - } - - // in quoted field - p = testParser(":1234\ndir1 \"Test {$FOOBAR} test\"") - blocks, _ = p.parseAll() - if actual, expected := blocks[0].Tokens["dir1"][1].text, "Test foobar test"; expected != actual { - t.Errorf("Expected argument to be '%s' but was '%s'", expected, actual) - } -} - -func setupParseTests() { - // Set up some bogus directives for testing - ValidDirectives = map[string]struct{}{ - "dir1": {}, - "dir2": {}, - "dir3": {}, - } -} - -func testParser(input string) parser { - buf := strings.NewReader(input) - p := parser{Dispenser: NewDispenser("Test", buf), checkDirectives: true} - return p -} diff --git a/core/restart.go b/core/restart.go deleted file mode 100644 index aa77e152d..000000000 --- a/core/restart.go +++ /dev/null @@ -1,177 +0,0 @@ -// +build !windows - -package core - -import ( - "bytes" - "encoding/gob" - "errors" - "io/ioutil" - "log" - "net" - "os" - "os/exec" - "path" - "sync/atomic" - - "github.com/miekg/coredns/core/https" -) - -func init() { - gob.Register(CorefileInput{}) -} - -// Restart restarts the entire application; gracefully with zero -// downtime if on a POSIX-compatible system, or forcefully if on -// Windows but with imperceptibly-short downtime. -// -// The restarted application will use newCorefile as its input -// configuration. If newCorefile is nil, the current (existing) -// Corefile configuration will be used. -// -// Note: The process must exist in the same place on the disk in -// order for this to work. Thus, multiple graceful restarts don't -// work if executing with `go run`, since the binary is cleaned up -// when `go run` sees the initial parent process exit. -func Restart(newCorefile Input) error { - log.Println("[INFO] Restarting") - - if newCorefile == nil { - corefileMu.Lock() - newCorefile = corefile - corefileMu.Unlock() - } - - // Get certificates for any new hosts in the new Corefile without causing downtime - err := getCertsForNewCorefile(newCorefile) - if err != nil { - return errors.New("TLS preload: " + err.Error()) - } - - if len(os.Args) == 0 { // this should never happen, but... - os.Args = []string{""} - } - - // Tell the child that it's a restart - os.Setenv("COREDNS_RESTART", "true") - - // Prepare our payload to the child process - crfileGob := corefileGob{ - ListenerFds: make(map[string]uintptr), - Corefile: newCorefile, - OnDemandTLSCertsIssued: atomic.LoadInt32(https.OnDemandIssuedCount), - } - - // Prepare a pipe to the fork's stdin so it can get the Corefile - rpipe, wpipe, err := os.Pipe() - if err != nil { - return err - } - - // Prepare a pipe that the child process will use to communicate - // its success with us by sending > 0 bytes - sigrpipe, sigwpipe, err := os.Pipe() - if err != nil { - return err - } - - // Pass along relevant file descriptors to child process; ordering - // is very important since we rely on these being in certain positions. - extraFiles := []*os.File{sigwpipe} // fd 3 - - // Add file descriptors of all the sockets - serversMu.Lock() - j := 0 - for _, s := range servers { - extraFiles = append(extraFiles, s.ListenerFd()) - extraFiles = append(extraFiles, s.PacketConnFd()) - // So this will be 0 1 2 3 TCP UDP TCP UDP ... etc. - crfileGob.ListenerFds["tcp"+s.Addr] = uintptr(4 + j) // 4 fds come before any of the listeners - crfileGob.ListenerFds["udp"+s.Addr] = uintptr(4 + j + 1) // add udp after that - j += 2 - } - serversMu.Unlock() - - // Set up the command - cmd := exec.Command(os.Args[0], os.Args[1:]...) - cmd.Stdin = rpipe // fd 0 - cmd.Stdout = os.Stdout // fd 1 - cmd.Stderr = os.Stderr // fd 2 - cmd.ExtraFiles = extraFiles - - // Spawn the child process - err = cmd.Start() - if err != nil { - return err - } - - // Immediately close our dup'ed fds and the write end of our signal pipe - for _, f := range extraFiles { - f.Close() - } - - // Feed Corefile to the child - err = gob.NewEncoder(wpipe).Encode(crfileGob) - if err != nil { - return err - } - wpipe.Close() - - // Run all shutdown functions for the middleware, if child start fails, restart them all... - executeShutdownCallbacks("SIGUSR1") - - // Determine whether child startup succeeded - answer, readErr := ioutil.ReadAll(sigrpipe) - if answer == nil || len(answer) == 0 { - cmdErr := cmd.Wait() // get exit status - log.Printf("[ERROR] Restart: child failed to initialize (%v) - changes not applied", cmdErr) - if readErr != nil { - log.Printf("[ERROR] Restart: additionally, error communicating with child process: %v", readErr) - } - // re-call all startup functions. - // TODO(miek): this needs to be tested, somehow. - executeStartupCallbacks("SIGUSR1") - return errIncompleteRestart - } - - // Looks like child is successful; we can exit gracefully. - return Stop() -} - -func getCertsForNewCorefile(newCorefile Input) error { - // parse the new corefile only up to (and including) TLS - // so we can know what we need to get certs for. - configs, _, _, err := loadConfigsUpToIncludingTLS(path.Base(newCorefile.Path()), bytes.NewReader(newCorefile.Body())) - if err != nil { - return errors.New("loading Corefile: " + err.Error()) - } - - // first mark the configs that are qualified for managed TLS - https.MarkQualified(configs) - - // since we group by bind address to obtain certs, we must call - // EnableTLS to make sure the port is set properly first - // (can ignore error since we aren't actually using the certs) - https.EnableTLS(configs, false) - - // find out if we can let the acme package start its own challenge listener - // on port 80 - var proxyACME bool - serversMu.Lock() - for _, s := range servers { - _, port, _ := net.SplitHostPort(s.Addr) - if port == "80" { - proxyACME = true - break - } - } - serversMu.Unlock() - - // place certs on the disk - err = https.ObtainCerts(configs, false, proxyACME) - if err != nil { - return errors.New("obtaining certs: " + err.Error()) - } - - return nil -} diff --git a/core/restart_windows.go b/core/restart_windows.go deleted file mode 100644 index 8a0805a19..000000000 --- a/core/restart_windows.go +++ /dev/null @@ -1,31 +0,0 @@ -package core - -import "log" - -// Restart restarts CoreDNS forcefully using newCorefile, -// or, if nil, the current/existing Corefile is reused. -func Restart(newCorefile Input) error { - log.Println("[INFO] Restarting") - - if newCorefile == nil { - corefileMu.Lock() - newCorefile = corefile - corefileMu.Unlock() - } - - wg.Add(1) // barrier so Wait() doesn't unblock - - err := Stop() - if err != nil { - return err - } - - err = Start(newCorefile) - if err != nil { - return err - } - - wg.Done() // take down our barrier - - return nil -} diff --git a/core/setup/bindhost.go b/core/setup/bindhost.go deleted file mode 100644 index a3c07e5eb..000000000 --- a/core/setup/bindhost.go +++ /dev/null @@ -1,13 +0,0 @@ -package setup - -import "github.com/miekg/coredns/middleware" - -// BindHost sets the host to bind the listener to. -func BindHost(c *Controller) (middleware.Middleware, error) { - for c.Next() { - if !c.Args(&c.BindHost) { - return nil, c.ArgErr() - } - } - return nil, nil -} diff --git a/core/setup/chaos.go b/core/setup/chaos.go deleted file mode 100644 index 13103adf4..000000000 --- a/core/setup/chaos.go +++ /dev/null @@ -1,45 +0,0 @@ -package setup - -import ( - "github.com/miekg/coredns/middleware" - "github.com/miekg/coredns/middleware/chaos" -) - -// Chaos configures a new Chaos middleware instance. -func Chaos(c *Controller) (middleware.Middleware, error) { - version, authors, err := chaosParse(c) - if err != nil { - return nil, err - } - - return func(next middleware.Handler) middleware.Handler { - return chaos.Chaos{ - Next: next, - Version: version, - Authors: authors, - } - }, nil -} - -func chaosParse(c *Controller) (string, map[string]bool, error) { - version := "" - authors := make(map[string]bool) - - for c.Next() { - args := c.RemainingArgs() - if len(args) == 0 { - return defaultVersion, nil, nil - } - if len(args) == 1 { - return args[0], nil, nil - } - version = args[0] - for _, a := range args[1:] { - authors[a] = true - } - return version, authors, nil - } - return version, authors, nil -} - -const defaultVersion = "CoreDNS" diff --git a/core/setup/controller.go b/core/setup/controller.go deleted file mode 100644 index 7f8da6721..000000000 --- a/core/setup/controller.go +++ /dev/null @@ -1,85 +0,0 @@ -package setup - -import ( - "fmt" - "strings" - - "golang.org/x/net/context" - - "github.com/miekg/coredns/core/parse" - "github.com/miekg/coredns/middleware" - "github.com/miekg/coredns/server" - "github.com/miekg/dns" -) - -// Controller is given to the setup function of middlewares which -// gives them access to be able to read tokens and set config. Each -// virtualhost gets their own server config and dispenser. -type Controller struct { - *server.Config - parse.Dispenser - - // OncePerServerBlock is a function that executes f - // exactly once per server block, no matter how many - // hosts are associated with it. If it is the first - // time, the function f is executed immediately - // (not deferred) and may return an error which is - // returned by OncePerServerBlock. - OncePerServerBlock func(f func() error) error - - // ServerBlockIndex is the 0-based index of the - // server block as it appeared in the input. - ServerBlockIndex int - - // ServerBlockHostIndex is the 0-based index of this - // host as it appeared in the input at the head of the - // server block. - ServerBlockHostIndex int - - // ServerBlockHosts is a list of hosts that are - // associated with this server block. All these - // hosts, consequently, share the same tokens. - ServerBlockHosts []string - - // ServerBlockStorage is used by a directive's - // setup function to persist state between all - // the hosts on a server block. - ServerBlockStorage interface{} -} - -// NewTestController creates a new *Controller for -// the input specified, with a filename of "Testfile". -// The Config is bare, consisting only of a Root of cwd. -// -// Used primarily for testing but needs to be exported so -// add-ons can use this as a convenience. Does not initialize -// the server-block-related fields. -func NewTestController(input string) *Controller { - return &Controller{ - Config: &server.Config{ - Root: ".", - }, - Dispenser: parse.NewDispenser("Testfile", strings.NewReader(input)), - OncePerServerBlock: func(f func() error) error { - return f() - }, - } -} - -// EmptyNext is a no-op function that can be passed into -// middleware.Middleware functions so that the assignment -// to the Next field of the Handler can be tested. -// -// Used primarily for testing but needs to be exported so -// add-ons can use this as a convenience. -var EmptyNext = middleware.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { - return 0, nil -}) - -// SameNext does a pointer comparison between next1 and next2. -// -// Used primarily for testing but needs to be exported so -// add-ons can use this as a convenience. -func SameNext(next1, next2 middleware.Handler) bool { - return fmt.Sprintf("%v", next1) == fmt.Sprintf("%v", next2) -} diff --git a/core/setup/health.go b/core/setup/health.go deleted file mode 100644 index 542cb3260..000000000 --- a/core/setup/health.go +++ /dev/null @@ -1,34 +0,0 @@ -package setup - -import ( - "github.com/miekg/coredns/middleware" - "github.com/miekg/coredns/middleware/health" -) - -func Health(c *Controller) (middleware.Middleware, error) { - addr, err := parseHealth(c) - if err != nil { - return nil, err - } - - h := &health.Health{Addr: addr} - c.Startup = append(c.Startup, h.Start) - c.Shutdown = append(c.Shutdown, h.Shutdown) - return nil, nil -} - -func parseHealth(c *Controller) (string, error) { - addr := "" - for c.Next() { - args := c.RemainingArgs() - - switch len(args) { - case 0: - case 1: - addr = args[0] - default: - return "", c.ArgErr() - } - } - return addr, nil -} diff --git a/core/setup/loadbalance.go b/core/setup/loadbalance.go deleted file mode 100644 index 4b132489b..000000000 --- a/core/setup/loadbalance.go +++ /dev/null @@ -1,16 +0,0 @@ -package setup - -import ( - "github.com/miekg/coredns/middleware" - "github.com/miekg/coredns/middleware/loadbalance" -) - -// Loadbalance sets up the root file path of the server. -func Loadbalance(c *Controller) (middleware.Middleware, error) { - for c.Next() { - // TODO(miek): block and option parsing - } - return func(next middleware.Handler) middleware.Handler { - return loadbalance.RoundRobin{Next: next} - }, nil -} diff --git a/core/setup/metrics.go b/core/setup/metrics.go deleted file mode 100644 index e88d93c86..000000000 --- a/core/setup/metrics.go +++ /dev/null @@ -1,72 +0,0 @@ -package setup - -import ( - "sync" - - "github.com/miekg/coredns/middleware" - "github.com/miekg/coredns/middleware/metrics" -) - -const addr = "localhost:9153" - -var metricsOnce sync.Once - -func Prometheus(c *Controller) (middleware.Middleware, error) { - m, err := parsePrometheus(c) - if err != nil { - return nil, err - } - - metricsOnce.Do(func() { - c.Startup = append(c.Startup, m.Start) - c.Shutdown = append(c.Shutdown, m.Shutdown) - }) - - return func(next middleware.Handler) middleware.Handler { - m.Next = next - return m - }, nil -} - -func parsePrometheus(c *Controller) (metrics.Metrics, error) { - var ( - met metrics.Metrics - err error - ) - - for c.Next() { - if len(met.ZoneNames) > 0 { - return metrics.Metrics{}, c.Err("metrics: can only have one metrics module per server") - } - met = metrics.Metrics{ZoneNames: c.ServerBlockHosts} - for i, _ := range met.ZoneNames { - met.ZoneNames[i] = middleware.Host(met.ZoneNames[i]).Normalize() - } - args := c.RemainingArgs() - - switch len(args) { - case 0: - case 1: - met.Addr = args[0] - default: - return metrics.Metrics{}, c.ArgErr() - } - for c.NextBlock() { - switch c.Val() { - case "address": - args = c.RemainingArgs() - if len(args) != 1 { - return metrics.Metrics{}, c.ArgErr() - } - met.Addr = args[0] - default: - return metrics.Metrics{}, c.Errf("metrics: unknown item: %s", c.Val()) - } - - } - } - if met.Addr == "" { - met.Addr = addr - } - return met, err -} diff --git a/core/setup/pprof.go b/core/setup/pprof.go deleted file mode 100644 index 125d2a9ef..000000000 --- a/core/setup/pprof.go +++ /dev/null @@ -1,33 +0,0 @@ -package setup - -import ( - "sync" - - "github.com/miekg/coredns/middleware" - "github.com/miekg/coredns/middleware/pprof" -) - -var pprofOnce sync.Once - -// PProf returns a new instance of a pprof handler. It accepts no arguments or options. -func PProf(c *Controller) (middleware.Middleware, error) { - found := false - for c.Next() { - if found { - return nil, c.Err("pprof can only be specified once") - } - if len(c.RemainingArgs()) != 0 { - return nil, c.ArgErr() - } - if c.NextBlock() { - return nil, c.ArgErr() - } - found = true - } - handler := &pprof.Handler{} - pprofOnce.Do(func() { - c.Startup = append(c.Startup, handler.Start) - c.Shutdown = append(c.Shutdown, handler.Shutdown) - }) - return nil, nil -} diff --git a/core/setup/proxy.go b/core/setup/proxy.go deleted file mode 100644 index 6753d07ad..000000000 --- a/core/setup/proxy.go +++ /dev/null @@ -1,17 +0,0 @@ -package setup - -import ( - "github.com/miekg/coredns/middleware" - "github.com/miekg/coredns/middleware/proxy" -) - -// Proxy configures a new Proxy middleware instance. -func Proxy(c *Controller) (middleware.Middleware, error) { - upstreams, err := proxy.NewStaticUpstreams(c.Dispenser) - if err != nil { - return nil, err - } - return func(next middleware.Handler) middleware.Handler { - return proxy.Proxy{Next: next, Client: proxy.Clients(), Upstreams: upstreams} - }, nil -} diff --git a/core/setup/rewrite_test.go b/core/setup/rewrite_test.go deleted file mode 100644 index 5345c4bf6..000000000 --- a/core/setup/rewrite_test.go +++ /dev/null @@ -1,234 +0,0 @@ -package setup - -/* -func TestRewrite(t *testing.T) { - c := NewTestController(`rewrite /from /to`) - - mid, err := Rewrite(c) - if err != nil { - t.Errorf("Expected no errors, but got: %v", err) - } - if mid == nil { - t.Fatal("Expected middleware, was nil instead") - } - - handler := mid(EmptyNext) - myHandler, ok := handler.(rewrite.Rewrite) - if !ok { - t.Fatalf("Expected handler to be type Rewrite, got: %#v", handler) - } - - if !SameNext(myHandler.Next, EmptyNext) { - t.Error("'Next' field of handler was not set properly") - } - - if len(myHandler.Rules) != 1 { - t.Errorf("Expected handler to have %d rule, has %d instead", 1, len(myHandler.Rules)) - } -} - -func TestRewriteParse(t *testing.T) { - simpleTests := []struct { - input string - shouldErr bool - expected []rewrite.Rule - }{ - {`rewrite /from /to`, false, []rewrite.Rule{ - rewrite.SimpleRule{From: "/from", To: "/to"}, - }}, - {`rewrite /from /to - rewrite a b`, false, []rewrite.Rule{ - rewrite.SimpleRule{From: "/from", To: "/to"}, - rewrite.SimpleRule{From: "a", To: "b"}, - }}, - {`rewrite a`, true, []rewrite.Rule{}}, - {`rewrite`, true, []rewrite.Rule{}}, - {`rewrite a b c`, false, []rewrite.Rule{ - rewrite.SimpleRule{From: "a", To: "b c"}, - }}, - } - - for i, test := range simpleTests { - c := NewTestController(test.input) - actual, err := rewriteParse(c) - - if err == nil && test.shouldErr { - t.Errorf("Test %d didn't error, but it should have", i) - } else if err != nil && !test.shouldErr { - t.Errorf("Test %d errored, but it shouldn't have; got '%v'", i, err) - } else if err != nil && test.shouldErr { - continue - } - - if len(actual) != len(test.expected) { - t.Fatalf("Test %d expected %d rules, but got %d", - i, len(test.expected), len(actual)) - } - - for j, e := range test.expected { - actualRule := actual[j].(rewrite.SimpleRule) - expectedRule := e.(rewrite.SimpleRule) - - if actualRule.From != expectedRule.From { - t.Errorf("Test %d, rule %d: Expected From=%s, got %s", - i, j, expectedRule.From, actualRule.From) - } - - if actualRule.To != expectedRule.To { - t.Errorf("Test %d, rule %d: Expected To=%s, got %s", - i, j, expectedRule.To, actualRule.To) - } - } - } - - regexpTests := []struct { - input string - shouldErr bool - expected []rewrite.Rule - }{ - {`rewrite { - r .* - to /to /index.php? - }`, false, []rewrite.Rule{ - &rewrite.ComplexRule{Base: "/", To: "/to /index.php?", Regexp: regexp.MustCompile(".*")}, - }}, - {`rewrite { - regexp .* - to /to - ext / html txt - }`, false, []rewrite.Rule{ - &rewrite.ComplexRule{Base: "/", To: "/to", Exts: []string{"/", "html", "txt"}, Regexp: regexp.MustCompile(".*")}, - }}, - {`rewrite /path { - r rr - to /dest - } - rewrite / { - regexp [a-z]+ - to /to /to2 - } - `, false, []rewrite.Rule{ - &rewrite.ComplexRule{Base: "/path", To: "/dest", Regexp: regexp.MustCompile("rr")}, - &rewrite.ComplexRule{Base: "/", To: "/to /to2", Regexp: regexp.MustCompile("[a-z]+")}, - }}, - {`rewrite { - r .* - }`, true, []rewrite.Rule{ - &rewrite.ComplexRule{}, - }}, - {`rewrite { - - }`, true, []rewrite.Rule{ - &rewrite.ComplexRule{}, - }}, - {`rewrite /`, true, []rewrite.Rule{ - &rewrite.ComplexRule{}, - }}, - {`rewrite { - to /to - if {path} is a - }`, false, []rewrite.Rule{ - &rewrite.ComplexRule{Base: "/", To: "/to", Ifs: []rewrite.If{{A: "{path}", Operator: "is", B: "a"}}}, - }}, - {`rewrite { - status 500 - }`, true, []rewrite.Rule{ - &rewrite.ComplexRule{}, - }}, - {`rewrite { - status 400 - }`, false, []rewrite.Rule{ - &rewrite.ComplexRule{Base: "/", Status: 400}, - }}, - {`rewrite { - to /to - status 400 - }`, false, []rewrite.Rule{ - &rewrite.ComplexRule{Base: "/", To: "/to", Status: 400}, - }}, - {`rewrite { - status 399 - }`, true, []rewrite.Rule{ - &rewrite.ComplexRule{}, - }}, - {`rewrite { - status 200 - }`, false, []rewrite.Rule{ - &rewrite.ComplexRule{Base: "/", Status: 200}, - }}, - {`rewrite { - to /to - status 200 - }`, false, []rewrite.Rule{ - &rewrite.ComplexRule{Base: "/", To: "/to", Status: 200}, - }}, - {`rewrite { - status 199 - }`, true, []rewrite.Rule{ - &rewrite.ComplexRule{}, - }}, - {`rewrite { - status 0 - }`, true, []rewrite.Rule{ - &rewrite.ComplexRule{}, - }}, - {`rewrite { - to /to - status 0 - }`, true, []rewrite.Rule{ - &rewrite.ComplexRule{}, - }}, - } - - for i, test := range regexpTests { - c := NewTestController(test.input) - actual, err := rewriteParse(c) - - if err == nil && test.shouldErr { - t.Errorf("Test %d didn't error, but it should have", i) - } else if err != nil && !test.shouldErr { - t.Errorf("Test %d errored, but it shouldn't have; got '%v'", i, err) - } else if err != nil && test.shouldErr { - continue - } - - if len(actual) != len(test.expected) { - t.Fatalf("Test %d expected %d rules, but got %d", - i, len(test.expected), len(actual)) - } - - for j, e := range test.expected { - actualRule := actual[j].(*rewrite.ComplexRule) - expectedRule := e.(*rewrite.ComplexRule) - - if actualRule.Base != expectedRule.Base { - t.Errorf("Test %d, rule %d: Expected Base=%s, got %s", - i, j, expectedRule.Base, actualRule.Base) - } - - if actualRule.To != expectedRule.To { - t.Errorf("Test %d, rule %d: Expected To=%s, got %s", - i, j, expectedRule.To, actualRule.To) - } - - if fmt.Sprint(actualRule.Exts) != fmt.Sprint(expectedRule.Exts) { - t.Errorf("Test %d, rule %d: Expected Ext=%v, got %v", - i, j, expectedRule.To, actualRule.To) - } - - if actualRule.Regexp != nil { - if actualRule.String() != expectedRule.String() { - t.Errorf("Test %d, rule %d: Expected Pattern=%s, got %s", - i, j, expectedRule.String(), actualRule.String()) - } - } - - if fmt.Sprint(actualRule.Ifs) != fmt.Sprint(expectedRule.Ifs) { - t.Errorf("Test %d, rule %d: Expected Pattern=%s, got %s", - i, j, fmt.Sprint(expectedRule.Ifs), fmt.Sprint(actualRule.Ifs)) - } - - } - } -} -*/ diff --git a/core/setup/roller.go b/core/setup/roller.go deleted file mode 100644 index fd772cc47..000000000 --- a/core/setup/roller.go +++ /dev/null @@ -1,40 +0,0 @@ -package setup - -import ( - "strconv" - - "github.com/miekg/coredns/middleware" -) - -func parseRoller(c *Controller) (*middleware.LogRoller, error) { - var size, age, keep int - // This is kind of a hack to support nested blocks: - // As we are already in a block: either log or errors, - // c.nesting > 0 but, as soon as c meets a }, it thinks - // the block is over and return false for c.NextBlock. - for c.NextBlock() { - what := c.Val() - if !c.NextArg() { - return nil, c.ArgErr() - } - value := c.Val() - var err error - switch what { - case "size": - size, err = strconv.Atoi(value) - case "age": - age, err = strconv.Atoi(value) - case "keep": - keep, err = strconv.Atoi(value) - } - if err != nil { - return nil, err - } - } - return &middleware.LogRoller{ - MaxSize: size, - MaxAge: age, - MaxBackups: keep, - LocalTime: true, - }, nil -} diff --git a/core/setup/root.go b/core/setup/root.go deleted file mode 100644 index 0fce5f170..000000000 --- a/core/setup/root.go +++ /dev/null @@ -1,32 +0,0 @@ -package setup - -import ( - "log" - "os" - - "github.com/miekg/coredns/middleware" -) - -// Root sets up the root file path of the server. -func Root(c *Controller) (middleware.Middleware, error) { - for c.Next() { - if !c.NextArg() { - return nil, c.ArgErr() - } - c.Root = c.Val() - } - - // Check if root path exists - _, err := os.Stat(c.Root) - if err != nil { - if os.IsNotExist(err) { - // Allow this, because the folder might appear later. - // But make sure the user knows! - log.Printf("[WARNING] Root path does not exist: %s", c.Root) - } else { - return nil, c.Errf("Unable to access root path '%s': %v", c.Root, err) - } - } - - return nil, nil -} diff --git a/core/setup/root_test.go b/core/setup/root_test.go deleted file mode 100644 index 8b38e6d04..000000000 --- a/core/setup/root_test.go +++ /dev/null @@ -1,108 +0,0 @@ -package setup - -import ( - "fmt" - "io/ioutil" - "os" - "path/filepath" - "strings" - "testing" -) - -func TestRoot(t *testing.T) { - - // Predefined error substrings - parseErrContent := "Parse error:" - unableToAccessErrContent := "Unable to access root path" - - existingDirPath, err := getTempDirPath() - if err != nil { - t.Fatalf("BeforeTest: Failed to find an existing directory for testing! Error was: %v", err) - } - - nonExistingDir := filepath.Join(existingDirPath, "highly_unlikely_to_exist_dir") - - existingFile, err := ioutil.TempFile("", "root_test") - if err != nil { - t.Fatalf("BeforeTest: Failed to create temp file for testing! Error was: %v", err) - } - defer func() { - existingFile.Close() - os.Remove(existingFile.Name()) - }() - - inaccessiblePath := getInaccessiblePath(existingFile.Name()) - - tests := []struct { - input string - shouldErr bool - expectedRoot string // expected root, set to the controller. Empty for negative cases. - expectedErrContent string // substring from the expected error. Empty for positive cases. - }{ - // positive - { - fmt.Sprintf(`root %s`, nonExistingDir), false, nonExistingDir, "", - }, - { - fmt.Sprintf(`root %s`, existingDirPath), false, existingDirPath, "", - }, - // negative - { - `root `, true, "", parseErrContent, - }, - { - fmt.Sprintf(`root %s`, inaccessiblePath), true, "", unableToAccessErrContent, - }, - { - fmt.Sprintf(`root { - %s - }`, existingDirPath), true, "", parseErrContent, - }, - } - - for i, test := range tests { - c := NewTestController(test.input) - mid, err := Root(c) - - if test.shouldErr && err == nil { - t.Errorf("Test %d: Expected error but found %s for input %s", i, err, test.input) - } - - if err != nil { - if !test.shouldErr { - t.Errorf("Test %d: Expected no error but found one for input %s. Error was: %v", i, test.input, err) - } - - if !strings.Contains(err.Error(), test.expectedErrContent) { - t.Errorf("Test %d: Expected error to contain: %v, found error: %v, input: %s", i, test.expectedErrContent, err, test.input) - } - } - - // the Root method always returns a nil middleware - if mid != nil { - t.Errorf("Middware, returned from Root() was not nil: %v", mid) - } - - // check c.Root only if we are in a positive test. - if !test.shouldErr && test.expectedRoot != c.Root { - t.Errorf("Root not correctly set for input %s. Expected: %s, actual: %s", test.input, test.expectedRoot, c.Root) - } - } -} - -// getTempDirPath returnes the path to the system temp directory. If it does not exists - an error is returned. -func getTempDirPath() (string, error) { - tempDir := os.TempDir() - - _, err := os.Stat(tempDir) - if err != nil { - return "", err - } - - return tempDir, nil -} - -func getInaccessiblePath(file string) string { - // null byte in filename is not allowed on Windows AND unix - return filepath.Join("C:", "file\x00name") -} diff --git a/core/setup/startupshutdown.go b/core/setup/startupshutdown.go deleted file mode 100644 index 1cf2c62e0..000000000 --- a/core/setup/startupshutdown.go +++ /dev/null @@ -1,64 +0,0 @@ -package setup - -import ( - "os" - "os/exec" - "strings" - - "github.com/miekg/coredns/middleware" -) - -// Startup registers a startup callback to execute during server start. -func Startup(c *Controller) (middleware.Middleware, error) { - return nil, registerCallback(c, &c.FirstStartup) -} - -// Shutdown registers a shutdown callback to execute during process exit. -func Shutdown(c *Controller) (middleware.Middleware, error) { - return nil, registerCallback(c, &c.Shutdown) -} - -// registerCallback registers a callback function to execute by -// using c to parse the line. It appends the callback function -// to the list of callback functions passed in by reference. -func registerCallback(c *Controller, list *[]func() error) error { - var funcs []func() error - - for c.Next() { - args := c.RemainingArgs() - if len(args) == 0 { - return c.ArgErr() - } - - nonblock := false - if len(args) > 1 && args[len(args)-1] == "&" { - // Run command in background; non-blocking - nonblock = true - args = args[:len(args)-1] - } - - command, args, err := middleware.SplitCommandAndArgs(strings.Join(args, " ")) - if err != nil { - return c.Err(err.Error()) - } - - fn := func() error { - cmd := exec.Command(command, args...) - cmd.Stdin = os.Stdin - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - - if nonblock { - return cmd.Start() - } - return cmd.Run() - } - - funcs = append(funcs, fn) - } - - return c.OncePerServerBlock(func() error { - *list = append(*list, funcs...) - return nil - }) -} diff --git a/core/setup/startupshutdown_test.go b/core/setup/startupshutdown_test.go deleted file mode 100644 index 871a64214..000000000 --- a/core/setup/startupshutdown_test.go +++ /dev/null @@ -1,59 +0,0 @@ -package setup - -import ( - "os" - "path/filepath" - "strconv" - "testing" - "time" -) - -// The Startup function's tests are symmetrical to Shutdown tests, -// because the Startup and Shutdown functions share virtually the -// same functionality -func TestStartup(t *testing.T) { - tempDirPath, err := getTempDirPath() - if err != nil { - t.Fatalf("BeforeTest: Failed to find an existing directory for testing! Error was: %v", err) - } - - testDir := filepath.Join(tempDirPath, "temp_dir_for_testing_startupshutdown") - defer func() { - // clean up after non-blocking startup function quits - time.Sleep(500 * time.Millisecond) - os.RemoveAll(testDir) - }() - osSenitiveTestDir := filepath.FromSlash(testDir) - os.RemoveAll(osSenitiveTestDir) // start with a clean slate - - tests := []struct { - input string - shouldExecutionErr bool - shouldRemoveErr bool - }{ - // test case #0 tests proper functionality blocking commands - {"startup mkdir " + osSenitiveTestDir, false, false}, - - // test case #1 tests proper functionality of non-blocking commands - {"startup mkdir " + osSenitiveTestDir + " &", false, true}, - - // test case #2 tests handling of non-existent commands - {"startup " + strconv.Itoa(int(time.Now().UnixNano())), true, true}, - } - - for i, test := range tests { - c := NewTestController(test.input) - _, err = Startup(c) - if err != nil { - t.Errorf("Expected no errors, got: %v", err) - } - err = c.FirstStartup[0]() - if err != nil && !test.shouldExecutionErr { - t.Errorf("Test %d recieved an error of:\n%v", i, err) - } - err = os.Remove(osSenitiveTestDir) - if err != nil && !test.shouldRemoveErr { - t.Errorf("Test %d recieved an error of:\n%v", i, err) - } - } -} diff --git a/core/setup/testdata/blog/first_post.md b/core/setup/testdata/blog/first_post.md deleted file mode 100644 index f26583b75..000000000 --- a/core/setup/testdata/blog/first_post.md +++ /dev/null @@ -1 +0,0 @@ -# Test h1 diff --git a/core/setup/testdata/header.html b/core/setup/testdata/header.html deleted file mode 100644 index 9c96e0e37..000000000 --- a/core/setup/testdata/header.html +++ /dev/null @@ -1 +0,0 @@ -

Header title

diff --git a/core/setup/testdata/tpl_with_include.html b/core/setup/testdata/tpl_with_include.html deleted file mode 100644 index 95eeae0c8..000000000 --- a/core/setup/testdata/tpl_with_include.html +++ /dev/null @@ -1,10 +0,0 @@ - - - -{{.Doc.title}} - - -{{.Include "header.html"}} -{{.Doc.body}} - - diff --git a/core/sigtrap.go b/core/sigtrap.go deleted file mode 100644 index f40dd971a..000000000 --- a/core/sigtrap.go +++ /dev/null @@ -1,93 +0,0 @@ -package core - -import ( - "log" - "os" - "os/signal" - "sync" - - "github.com/miekg/coredns/server" -) - -// TrapSignals create signal handlers for all applicable signals for this -// system. If your Go program uses signals, this is a rather invasive -// function; best to implement them yourself in that case. Signals are not -// required for the caddy package to function properly, but this is a -// convenient way to allow the user to control this package of your program. -func TrapSignals() { - trapSignalsCrossPlatform() - trapSignalsPosix() -} - -// trapSignalsCrossPlatform captures SIGINT, which triggers forceful -// shutdown that executes shutdown callbacks first. A second interrupt -// signal will exit the process immediately. -func trapSignalsCrossPlatform() { - go func() { - shutdown := make(chan os.Signal, 1) - signal.Notify(shutdown, os.Interrupt) - - for i := 0; true; i++ { - <-shutdown - - if i > 0 { - log.Println("[INFO] SIGINT: Force quit") - if PidFile != "" { - os.Remove(PidFile) - } - os.Exit(1) - } - - log.Println("[INFO] SIGINT: Shutting down") - - if PidFile != "" { - os.Remove(PidFile) - } - - go os.Exit(executeShutdownCallbacks("SIGINT")) - } - }() -} - -// executeShutdownCallbacks executes the shutdown callbacks as initiated -// by signame. It logs any errors and returns the recommended exit status. -// This function is idempotent; subsequent invocations always return 0. -func executeShutdownCallbacks(signame string) (exitCode int) { - shutdownCallbacksOnce.Do(func() { - serversMu.Lock() - errs := server.ShutdownCallbacks(servers) - serversMu.Unlock() - - if len(errs) > 0 { - for _, err := range errs { - log.Printf("[ERROR] %s shutdown: %v", signame, err) - } - exitCode = 1 - } - }) - return -} - -// executeStartupCallbacks executes the startup callbacks as initiated -// by signame. This is used when on restart when the child failed to start and -// all middleware executed their shutdown functions -func executeStartupCallbacks(signame string) (exitCode int) { - startupCallbacksOnce.Do(func() { - serversMu.Lock() - errs := server.StartupCallbacks(servers) - serversMu.Unlock() - - if len(errs) > 0 { - for _, err := range errs { - log.Printf("[ERROR] %s shutdown: %v", signame, err) - } - exitCode = 1 - } - }) - return -} - -var ( - shutdownCallbacksOnce sync.Once - startupCallbacksOnce sync.Once -) diff --git a/core/sigtrap_posix.go b/core/sigtrap_posix.go deleted file mode 100644 index d120dd52b..000000000 --- a/core/sigtrap_posix.go +++ /dev/null @@ -1,79 +0,0 @@ -// +build !windows - -package core - -import ( - "io/ioutil" - "log" - "os" - "os/signal" - "syscall" -) - -// trapSignalsPosix captures POSIX-only signals. -func trapSignalsPosix() { - go func() { - sigchan := make(chan os.Signal, 1) - signal.Notify(sigchan, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT, syscall.SIGUSR1) - - for sig := range sigchan { - switch sig { - case syscall.SIGTERM: - log.Println("[INFO] SIGTERM: Terminating process") - if PidFile != "" { - os.Remove(PidFile) - } - os.Exit(0) - - case syscall.SIGQUIT: - log.Println("[INFO] SIGQUIT: Shutting down") - exitCode := executeShutdownCallbacks("SIGQUIT") - err := Stop() - if err != nil { - log.Printf("[ERROR] SIGQUIT stop: %v", err) - exitCode = 1 - } - if PidFile != "" { - os.Remove(PidFile) - } - os.Exit(exitCode) - - case syscall.SIGHUP: - log.Println("[INFO] SIGHUP: Hanging up") - err := Stop() - if err != nil { - log.Printf("[ERROR] SIGHUP stop: %v", err) - } - - case syscall.SIGUSR1: - log.Println("[INFO] SIGUSR1: Reloading") - - var updatedCorefile Input - - corefileMu.Lock() - if corefile == nil { - // Hmm, did spawing process forget to close stdin? Anyhow, this is unusual. - log.Println("[ERROR] SIGUSR1: no Corefile to reload (was stdin left open?)") - corefileMu.Unlock() - continue - } - if corefile.IsFile() { - body, err := ioutil.ReadFile(corefile.Path()) - if err == nil { - updatedCorefile = CorefileInput{ - Filepath: corefile.Path(), - Contents: body, - RealFile: true, - } - } - } - corefileMu.Unlock() - - err := Restart(updatedCorefile) - if err != nil { - log.Printf("[ERROR] SIGUSR1: %v", err) - } - } - } - }() -} diff --git a/core/sigtrap_windows.go b/core/sigtrap_windows.go deleted file mode 100644 index 59132cee4..000000000 --- a/core/sigtrap_windows.go +++ /dev/null @@ -1,3 +0,0 @@ -package core - -func trapSignalsPosix() {} diff --git a/coredns.go b/coredns.go new file mode 100644 index 000000000..e7e54dc60 --- /dev/null +++ b/coredns.go @@ -0,0 +1,18 @@ +package main + +import ( + "flag" + + "github.com/mholt/caddy" + "github.com/mholt/caddy/caddy/caddymain" +) + +//go:generate go run plugin_generate.go + +func main() { + // Set some flags/options specific for CoreDNS. + flag.Set("type", "dns") + caddy.DefaultConfigFile = "Corefile" + + caddymain.Run() +} diff --git a/main.go b/main.go deleted file mode 100644 index 1900fc60c..000000000 --- a/main.go +++ /dev/null @@ -1,7 +0,0 @@ -package main - -import "github.com/miekg/coredns/core/coremain" - -func main() { - coremain.Run() -} diff --git a/middleware/bind/README.md b/middleware/bind/README.md new file mode 100644 index 000000000..ad23b6153 --- /dev/null +++ b/middleware/bind/README.md @@ -0,0 +1,21 @@ +# bind + +bind overrides the host to which the server should bind. Normally, the listener binds to the +wildcard host. However, you may force the listener to bind to another IP instead. This +directive accepts only an address, not a port. + +## Syntax + +~~~ txt +bind address +~~~ + +address is the IP address to bind to. + +## Examples + +To make your socket accessible only to that machine, bind to IP 127.0.0.1 (localhost): + +~~~ txt +bind 127.0.0.1 +~~~ diff --git a/middleware/bind/bind.go b/middleware/bind/bind.go new file mode 100644 index 000000000..ac27c993b --- /dev/null +++ b/middleware/bind/bind.go @@ -0,0 +1,10 @@ +package bind + +import "github.com/mholt/caddy" + +func init() { + caddy.RegisterPlugin("bind", caddy.Plugin{ + ServerType: "dns", + Action: setupBind, + }) +} diff --git a/middleware/bind/bind_test.go b/middleware/bind/bind_test.go new file mode 100644 index 000000000..d61741a02 --- /dev/null +++ b/middleware/bind/bind_test.go @@ -0,0 +1,30 @@ +package bind + +import ( + "testing" + + "github.com/miekg/coredns/core/dnsserver" + + "github.com/mholt/caddy" +) + +func TestSetupBind(t *testing.T) { + c := caddy.NewTestController("dns", `bind 1.2.3.4`) + err := setupBind(c) + if err != nil { + t.Fatalf("Expected no errors, but got: %v", err) + } + + cfg := dnsserver.GetConfig(c) + if got, want := cfg.ListenHost, "1.2.3.4"; got != want { + t.Errorf("Expected the config's ListenHost to be %s, was %s", want, got) + } +} + +func TestBindAddress(t *testing.T) { + c := caddy.NewTestController("dns", `bind 1.2.3.bla`) + err := setupBind(c) + if err == nil { + t.Fatalf("Expected errors, but got none") + } +} diff --git a/middleware/bind/setup.go b/middleware/bind/setup.go new file mode 100644 index 000000000..c08098b5d --- /dev/null +++ b/middleware/bind/setup.go @@ -0,0 +1,23 @@ +package bind + +import ( + "fmt" + "net" + + "github.com/miekg/coredns/core/dnsserver" + + "github.com/mholt/caddy" +) + +func setupBind(c *caddy.Controller) error { + config := dnsserver.GetConfig(c) + for c.Next() { + if !c.Args(&config.ListenHost) { + return c.ArgErr() + } + } + if net.ParseIP(config.ListenHost) == nil { + return fmt.Errorf("not a valid IP address: %s", config.ListenHost) + } + return nil +} diff --git a/core/setup/cache.go b/middleware/cache/setup.go similarity index 54% rename from core/setup/cache.go rename to middleware/cache/setup.go index f5a1cf0d9..ab7f423f2 100644 --- a/core/setup/cache.go +++ b/middleware/cache/setup.go @@ -1,33 +1,46 @@ -package setup +package cache import ( "strconv" + "github.com/miekg/coredns/core/dnsserver" "github.com/miekg/coredns/middleware" - "github.com/miekg/coredns/middleware/cache" + + "github.com/mholt/caddy" ) -// Cache sets up the root file path of the server. -func Cache(c *Controller) (middleware.Middleware, error) { - ttl, zones, err := cacheParse(c) - if err != nil { - return nil, err - } - return func(next middleware.Handler) middleware.Handler { - return cache.NewCache(ttl, zones, next) - }, nil +func init() { + caddy.RegisterPlugin("cache", caddy.Plugin{ + ServerType: "dns", + Action: setup, + }) } -func cacheParse(c *Controller) (int, []string, error) { +// Cache sets up the root file path of the server. +func setup(c *caddy.Controller) error { + ttl, zones, err := cacheParse(c) + if err != nil { + return err + } + dnsserver.GetConfig(c).AddMiddleware(func(next dnsserver.Handler) dnsserver.Handler { + return NewCache(ttl, zones, next) + }) + + return nil +} + +func cacheParse(c *caddy.Controller) (int, []string, error) { var ( - err error - ttl int + err error + ttl int + origins []string ) for c.Next() { if c.Val() == "cache" { // cache [ttl] [zones..] - origins := c.ServerBlockHosts + origins = make([]string, len(c.ServerBlockKeys)) + copy(origins, c.ServerBlockKeys) args := c.RemainingArgs() if len(args) > 0 { origins = args @@ -38,7 +51,7 @@ func cacheParse(c *Controller) (int, []string, error) { origins = origins[1:] if len(origins) == 0 { // There was *only* the ttl, revert back to server block - origins = c.ServerBlockHosts + copy(origins, c.ServerBlockKeys) } } } diff --git a/middleware/chaos/setup.go b/middleware/chaos/setup.go new file mode 100644 index 000000000..8bdb3053e --- /dev/null +++ b/middleware/chaos/setup.go @@ -0,0 +1,50 @@ +package chaos + +import ( + "github.com/miekg/coredns/core/dnsserver" + + "github.com/mholt/caddy" +) + +func init() { + caddy.RegisterPlugin("chaos", caddy.Plugin{ + ServerType: "dns", + Action: setup, + }) +} + +func setup(c *caddy.Controller) error { + version, authors, err := chaosParse(c) + if err != nil { + return err + } + + dnsserver.GetConfig(c).AddMiddleware(func(next dnsserver.Handler) dnsserver.Handler { + return Chaos{Next: next, Version: version, Authors: authors} + }) + + return nil +} + +func chaosParse(c *caddy.Controller) (string, map[string]bool, error) { + version := "" + authors := make(map[string]bool) + + for c.Next() { + args := c.RemainingArgs() + if len(args) == 0 { + return defaultVersion, nil, nil + } + if len(args) == 1 { + return args[0], nil, nil + } + version = args[0] + for _, a := range args[1:] { + authors[a] = true + } + return version, authors, nil + } + return version, authors, nil +} + +const defaultVersion = "CoreDNS" diff --git a/core/setup/chaos_test.go b/middleware/chaos/setup_test.go similarity index 92% rename from core/setup/chaos_test.go rename to middleware/chaos/setup_test.go index 8431cecef..c1741cdf6 100644 --- a/core/setup/chaos_test.go +++ b/middleware/chaos/setup_test.go @@ -1,12 +1,14 @@ -package setup +package chaos import ( "fmt" "strings" "testing" + + "github.com/mholt/caddy" ) -func TestChaos(t *testing.T) { +func TestSetupChaos(t *testing.T) { tests := []struct { input string shouldErr bool @@ -32,7 +34,7 @@ func TestChaos(t *testing.T) { } for i, test := range tests { - c := NewTestController(test.input) + c := caddy.NewTestController("dns", test.input) version, authors, err := chaosParse(c) if test.shouldErr && err == nil { diff --git a/middleware/dnssec/cache_test.go b/middleware/dnssec/cache_test.go index 0039586d5..3062f99b0 100644 --- a/middleware/dnssec/cache_test.go +++ b/middleware/dnssec/cache_test.go @@ -22,7 +22,7 @@ func TestCacheSet(t *testing.T) { m := testMsg() state := middleware.State{Req: m} k := key(m.Answer) // calculate *before* we add the sig - d := NewDnssec([]string{"miek.nl."}, []*DNSKEY{dnskey}, nil) + d := New([]string{"miek.nl."}, []*DNSKEY{dnskey}, nil) m = d.Sign(state, "miek.nl.", time.Now().UTC()) _, ok := d.get(k) diff --git a/middleware/dnssec/dnssec.go b/middleware/dnssec/dnssec.go index 4907f1621..f517bfe2c 100644 --- a/middleware/dnssec/dnssec.go +++ b/middleware/dnssec/dnssec.go @@ -11,14 +11,15 @@ import ( ) type Dnssec struct { - Next middleware.Handler + Next middleware.Handler + zones []string keys []*DNSKEY inflight *singleflight.Group cache *gcache.Cache } -func NewDnssec(zones []string, keys []*DNSKEY, next middleware.Handler) Dnssec { +func New(zones []string, keys []*DNSKEY, next middleware.Handler) Dnssec { return Dnssec{Next: next, zones: zones, keys: keys, diff --git a/middleware/dnssec/dnssec_test.go b/middleware/dnssec/dnssec_test.go index 49b0d5d3a..10f731325 100644 --- a/middleware/dnssec/dnssec_test.go +++ b/middleware/dnssec/dnssec_test.go @@ -69,7 +69,7 @@ func TestSigningDifferentZone(t *testing.T) { m := testMsgEx() state := middleware.State{Req: m} - d := NewDnssec([]string{"example.org."}, []*DNSKEY{key}, nil) + d := New([]string{"example.org."}, []*DNSKEY{key}, nil) m = d.Sign(state, "example.org.", time.Now().UTC()) if !section(m.Answer, 1) { t.Errorf("answer section should have 1 sig") @@ -158,7 +158,7 @@ func testDelegationMsg() *dns.Msg { func newDnssec(t *testing.T, zones []string) (Dnssec, func(), func()) { k, rm1, rm2 := newKey(t) - d := NewDnssec(zones, []*DNSKEY{k}, nil) + d := New(zones, []*DNSKEY{k}, nil) return d, rm1, rm2 } diff --git a/middleware/dnssec/handler_test.go b/middleware/dnssec/handler_test.go index 6f537b90e..f7cb7e680 100644 --- a/middleware/dnssec/handler_test.go +++ b/middleware/dnssec/handler_test.go @@ -77,7 +77,7 @@ func TestLookupZone(t *testing.T) { dnskey, rm1, rm2 := newKey(t) defer rm1() defer rm2() - dh := NewDnssec([]string{"miek.nl."}, []*DNSKEY{dnskey}, fm) + dh := New([]string{"miek.nl."}, []*DNSKEY{dnskey}, fm) ctx := context.TODO() for _, tc := range dnsTestCases { @@ -115,7 +115,7 @@ func TestLookupDNSKEY(t *testing.T) { dnskey, rm1, rm2 := newKey(t) defer rm1() defer rm2() - dh := NewDnssec([]string{"miek.nl."}, []*DNSKEY{dnskey}, test.ErrorHandler()) + dh := New([]string{"miek.nl."}, []*DNSKEY{dnskey}, test.ErrorHandler()) ctx := context.TODO() for _, tc := range dnssecTestCases { diff --git a/core/setup/dnssec.go b/middleware/dnssec/setup.go similarity index 60% rename from core/setup/dnssec.go rename to middleware/dnssec/setup.go index 39f34b66f..999f85bf8 100644 --- a/core/setup/dnssec.go +++ b/middleware/dnssec/setup.go @@ -1,32 +1,43 @@ -package setup +package dnssec import ( "strings" + "github.com/miekg/coredns/core/dnsserver" "github.com/miekg/coredns/middleware" - "github.com/miekg/coredns/middleware/dnssec" + + "github.com/mholt/caddy" ) -// Dnssec sets up the dnssec middleware. -func Dnssec(c *Controller) (middleware.Middleware, error) { - zones, keys, err := dnssecParse(c) - if err != nil { - return nil, err - } - - return func(next middleware.Handler) middleware.Handler { - return dnssec.NewDnssec(zones, keys, next) - }, nil +func init() { + caddy.RegisterPlugin("dnssec", caddy.Plugin{ + ServerType: "dns", + Action: setup, + }) } -func dnssecParse(c *Controller) ([]string, []*dnssec.DNSKEY, error) { +func setup(c *caddy.Controller) error { + zones, keys, err := dnssecParse(c) + if err != nil { + return err + } + + dnsserver.GetConfig(c).AddMiddleware(func(next dnsserver.Handler) dnsserver.Handler { + return New(zones, keys, next) + }) + + return nil +} + +func dnssecParse(c *caddy.Controller) ([]string, []*DNSKEY, error) { zones := []string{} - keys := []*dnssec.DNSKEY{} + keys := []*DNSKEY{} for c.Next() { if c.Val() == "dnssec" { // dnssec [zones...] - zones = c.ServerBlockHosts + zones = make([]string, len(c.ServerBlockKeys)) + copy(zones, c.ServerBlockKeys) args := c.RemainingArgs() if len(args) > 0 { zones = args @@ -47,8 +58,8 @@ func dnssecParse(c *Controller) ([]string, []*dnssec.DNSKEY, error) { return zones, keys, nil } -func keyParse(c *Controller) ([]*dnssec.DNSKEY, error) { - keys := []*dnssec.DNSKEY{} +func keyParse(c *caddy.Controller) ([]*DNSKEY, error) { + keys := []*DNSKEY{} what := c.Val() if !c.NextArg() { @@ -68,7 +79,7 @@ func keyParse(c *Controller) ([]*dnssec.DNSKEY, error) { if strings.HasSuffix(k, ".private") { base = k[:len(k)-8] } - k, err := dnssec.ParseKeyFile(base+".key", base+".private") + k, err := ParseKeyFile(base+".key", base+".private") if err != nil { return nil, err } diff --git a/core/setup/dnssec_test.go b/middleware/dnssec/setup_test.go similarity index 90% rename from core/setup/dnssec_test.go rename to middleware/dnssec/setup_test.go index 364a363bd..9dbeb77fd 100644 --- a/core/setup/dnssec_test.go +++ b/middleware/dnssec/setup_test.go @@ -1,11 +1,13 @@ -package setup +package dnssec import ( "strings" "testing" + + "github.com/mholt/caddy" ) -func TestDnssec(t *testing.T) { +func TestSetupDnssec(t *testing.T) { tests := []struct { input string shouldErr bool @@ -22,7 +24,7 @@ func TestDnssec(t *testing.T) { } for i, test := range tests { - c := NewTestController(test.input) + c := caddy.NewTestController("dns", test.input) zones, keys, err := dnssecParse(c) if test.shouldErr && err == nil { diff --git a/core/setup/errors.go b/middleware/errors/setup.go similarity index 78% rename from core/setup/errors.go rename to middleware/errors/setup.go index bf6b56f87..e1c77373d 100644 --- a/core/setup/errors.go +++ b/middleware/errors/setup.go @@ -1,21 +1,28 @@ -package setup +package errors import ( "io" "log" "os" + "github.com/miekg/coredns/core/dnsserver" "github.com/miekg/coredns/middleware" - "github.com/miekg/coredns/middleware/errors" "github.com/hashicorp/go-syslog" + "github.com/mholt/caddy" ) -// Errors configures a new errors middleware instance. -func Errors(c *Controller) (middleware.Middleware, error) { +func init() { + caddy.RegisterPlugin("errors", caddy.Plugin{ + ServerType: "dns", + Action: setup, + }) +} + +func setup(c *caddy.Controller) error { handler, err := errorsParse(c) if err != nil { - return nil, err + return err } var writer io.Writer @@ -30,7 +37,7 @@ func Errors(c *Controller) (middleware.Middleware, error) { case "syslog": writer, err = gsyslog.NewLogger(gsyslog.LOG_ERR, "LOCAL0", "coredns") if err != nil { - return nil, err + return err } default: if handler.LogFile == "" { @@ -41,7 +48,7 @@ func Errors(c *Controller) (middleware.Middleware, error) { var file *os.File file, err = os.OpenFile(handler.LogFile, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0644) if err != nil { - return nil, err + return err } if handler.LogRoller != nil { file.Close() @@ -55,14 +62,16 @@ func Errors(c *Controller) (middleware.Middleware, error) { } handler.Log = log.New(writer, "", 0) - return func(next middleware.Handler) middleware.Handler { + dnsserver.GetConfig(c).AddMiddleware(func(next dnsserver.Handler) dnsserver.Handler { handler.Next = next return handler - }, nil + }) + + return nil } -func errorsParse(c *Controller) (errors.ErrorHandler, error) { - handler := errors.ErrorHandler{} +func errorsParse(c *caddy.Controller) (ErrorHandler, error) { + handler := ErrorHandler{} optionalBlock := func() (bool, error) { var hadBlock bool @@ -84,7 +93,7 @@ func errorsParse(c *Controller) (errors.ErrorHandler, error) { if c.NextArg() { if c.Val() == "{" { c.IncrNest() - logRoller, err := parseRoller(c) + logRoller, err := middleware.ParseRoller(c) if err != nil { return hadBlock, err } diff --git a/core/setup/errors_test.go b/middleware/errors/setup_test.go similarity index 74% rename from core/setup/errors_test.go rename to middleware/errors/setup_test.go index 42f625f92..6e5a85d08 100644 --- a/core/setup/errors_test.go +++ b/middleware/errors/setup_test.go @@ -1,62 +1,33 @@ -package setup +package errors import ( "testing" + "github.com/mholt/caddy" "github.com/miekg/coredns/middleware" - "github.com/miekg/coredns/middleware/errors" ) -func TestErrors(t *testing.T) { - c := NewTestController(`errors`) - mid, err := Errors(c) - - if err != nil { - t.Errorf("Expected no errors, got: %v", err) - } - - if mid == nil { - t.Fatal("Expected middleware, was nil instead") - } - - handler := mid(EmptyNext) - myHandler, ok := handler.(errors.ErrorHandler) - if !ok { - t.Fatalf("Expected handler to be type ErrorHandler, got: %#v", handler) - } - - if myHandler.LogFile != "" { - t.Errorf("Expected '%s' as the default LogFile", "") - } - if myHandler.LogRoller != nil { - t.Errorf("Expected LogRoller to be nil, got: %v", *myHandler.LogRoller) - } - if !SameNext(myHandler.Next, EmptyNext) { - t.Error("'Next' field of handler was not set properly") - } -} - func TestErrorsParse(t *testing.T) { tests := []struct { inputErrorsRules string shouldErr bool - expectedErrorHandler errors.ErrorHandler + expectedErrorHandler ErrorHandler }{ - {`errors`, false, errors.ErrorHandler{ + {`errors`, false, ErrorHandler{ LogFile: "", }}, - {`errors errors.txt`, false, errors.ErrorHandler{ + {`errors errors.txt`, false, ErrorHandler{ LogFile: "errors.txt", }}, - {`errors visible`, false, errors.ErrorHandler{ + {`errors visible`, false, ErrorHandler{ LogFile: "", Debug: true, }}, - {`errors { log visible }`, false, errors.ErrorHandler{ + {`errors { log visible }`, false, ErrorHandler{ LogFile: "", Debug: true, }}, - {`errors { log errors.txt { size 2 age 10 keep 3 } }`, false, errors.ErrorHandler{ + {`errors { log errors.txt { size 2 age 10 keep 3 } }`, false, ErrorHandler{ LogFile: "errors.txt", LogRoller: &middleware.LogRoller{ MaxSize: 2, @@ -70,7 +41,7 @@ func TestErrorsParse(t *testing.T) { age 11 keep 5 } -}`, false, errors.ErrorHandler{ +}`, false, ErrorHandler{ LogFile: "errors.txt", LogRoller: &middleware.LogRoller{ MaxSize: 3, @@ -81,7 +52,7 @@ func TestErrorsParse(t *testing.T) { }}, } for i, test := range tests { - c := NewTestController(test.inputErrorsRules) + c := caddy.NewTestController("dns", test.inputErrorsRules) actualErrorsRule, err := errorsParse(c) if err == nil && test.shouldErr { diff --git a/middleware/etcd/cname_test.go b/middleware/etcd/cname_test.go index 4a00c05c2..ee341b7b6 100644 --- a/middleware/etcd/cname_test.go +++ b/middleware/etcd/cname_test.go @@ -16,6 +16,8 @@ import ( // Check the ordering of returned cname. func TestCnameLookup(t *testing.T) { + etc := newEtcdMiddleware() + for _, serv := range servicesCname { set(t, etc, serv.Key, 0, serv) defer delete(t, etc, serv.Key) diff --git a/middleware/etcd/debug_test.go b/middleware/etcd/debug_test.go index 91796816f..82de9fe1f 100644 --- a/middleware/etcd/debug_test.go +++ b/middleware/etcd/debug_test.go @@ -30,12 +30,13 @@ func TestIsDebug(t *testing.T) { } func TestDebugLookup(t *testing.T) { + etc := newEtcdMiddleware() + etc.Debug = true + for _, serv := range servicesDebug { set(t, etc, serv.Key, 0, serv) defer delete(t, etc, serv.Key) } - etc.Debug = true - defer func() { etc.Debug = false }() for _, tc := range dnsTestCasesDebug { m := tc.Msg() @@ -69,6 +70,8 @@ func TestDebugLookup(t *testing.T) { } func TestDebugLookupFalse(t *testing.T) { + etc := newEtcdMiddleware() + for _, serv := range servicesDebug { set(t, etc, serv.Key, 0, serv) defer delete(t, etc, serv.Key) diff --git a/middleware/etcd/group_test.go b/middleware/etcd/group_test.go index f5283e3ba..7a2808d45 100644 --- a/middleware/etcd/group_test.go +++ b/middleware/etcd/group_test.go @@ -14,6 +14,8 @@ import ( ) func TestGroupLookup(t *testing.T) { + etc := newEtcdMiddleware() + for _, serv := range servicesGroup { set(t, etc, serv.Key, 0, serv) defer delete(t, etc, serv.Key) diff --git a/middleware/etcd/multi_test.go b/middleware/etcd/multi_test.go index 19e6ca7a3..f4b59f50b 100644 --- a/middleware/etcd/multi_test.go +++ b/middleware/etcd/multi_test.go @@ -14,10 +14,9 @@ import ( ) func TestMultiLookup(t *testing.T) { + etc := newEtcdMiddleware() etc.Zones = []string{"skydns.test.", "miek.nl."} - defer func() { etc.Zones = []string{"skydns.test.", "skydns_extra.test.", "in-addr.arpa."} }() etc.Next = test.ErrorHandler() - defer func() { etc.Next = nil }() for _, serv := range servicesMulti { set(t, etc, serv.Key, 0, serv) diff --git a/middleware/etcd/other_test.go b/middleware/etcd/other_test.go index 34971c6f1..ff37d27d2 100644 --- a/middleware/etcd/other_test.go +++ b/middleware/etcd/other_test.go @@ -18,6 +18,8 @@ import ( ) func TestOtherLookup(t *testing.T) { + etc := newEtcdMiddleware() + for _, serv := range servicesOther { set(t, etc, serv.Key, 0, serv) defer delete(t, etc, serv.Key) diff --git a/middleware/etcd/proxy_lookup_test.go b/middleware/etcd/proxy_lookup_test.go index 9a31eee24..5e0999fb0 100644 --- a/middleware/etcd/proxy_lookup_test.go +++ b/middleware/etcd/proxy_lookup_test.go @@ -15,18 +15,15 @@ import ( ) func TestProxyLookupFailDebug(t *testing.T) { + etc := newEtcdMiddleware() + etc.Proxy = proxy.New([]string{"127.0.0.1:154"}) + etc.Debug = true + for _, serv := range servicesProxy { set(t, etc, serv.Key, 0, serv) defer delete(t, etc, serv.Key) } - prxy := etc.Proxy - etc.Proxy = proxy.New([]string{"127.0.0.1:154"}) - defer func() { etc.Proxy = prxy }() - - etc.Debug = true - defer func() { etc.Debug = false }() - for _, tc := range dnsTestCasesProxy { m := tc.Msg() diff --git a/core/setup/etcd.go b/middleware/etcd/setup.go similarity index 79% rename from core/setup/etcd.go rename to middleware/etcd/setup.go index b90297abd..dc1dddb0e 100644 --- a/core/setup/etcd.go +++ b/middleware/etcd/setup.go @@ -1,4 +1,4 @@ -package setup +package etcd import ( "crypto/tls" @@ -8,39 +8,46 @@ import ( "net/http" "time" + "github.com/miekg/coredns/core/dnsserver" "github.com/miekg/coredns/middleware" - "github.com/miekg/coredns/middleware/etcd" "github.com/miekg/coredns/middleware/proxy" "github.com/miekg/coredns/singleflight" etcdc "github.com/coreos/etcd/client" + "github.com/mholt/caddy" "golang.org/x/net/context" ) -const defaultEndpoint = "http://localhost:2379" +func init() { + caddy.RegisterPlugin("etcd", caddy.Plugin{ + ServerType: "dns", + Action: setup, + }) +} -// Etcd sets up the etcd middleware. -func Etcd(c *Controller) (middleware.Middleware, error) { - etcd, stubzones, err := etcdParse(c) +func setup(c *caddy.Controller) error { + e, stubzones, err := etcdParse(c) if err != nil { - return nil, err + return err } if stubzones { - c.Startup = append(c.Startup, func() error { - etcd.UpdateStubZones() + c.OnStartup(func() error { + e.UpdateStubZones() return nil }) } - return func(next middleware.Handler) middleware.Handler { - etcd.Next = next - return etcd - }, nil + dnsserver.GetConfig(c).AddMiddleware(func(next dnsserver.Handler) dnsserver.Handler { + e.Next = next + return e + }) + + return nil } -func etcdParse(c *Controller) (*etcd.Etcd, bool, error) { +func etcdParse(c *caddy.Controller) (*Etcd, bool, error) { stub := make(map[string]proxy.Proxy) - etc := etcd.Etcd{ + etc := Etcd{ Proxy: proxy.New([]string{"8.8.8.8:53", "8.8.4.4:53"}), PathPrefix: "skydns", Ctx: context.Background(), @@ -60,7 +67,8 @@ func etcdParse(c *Controller) (*etcd.Etcd, bool, error) { etc.Client = client etc.Zones = c.RemainingArgs() if len(etc.Zones) == 0 { - etc.Zones = c.ServerBlockHosts + etc.Zones = make([]string, len(c.ServerBlockKeys)) + copy(etc.Zones, c.ServerBlockKeys) } middleware.Zones(etc.Zones).FullyQualify() if c.NextBlock() { @@ -72,19 +80,19 @@ func etcdParse(c *Controller) (*etcd.Etcd, bool, error) { etc.Debug = true case "path": if !c.NextArg() { - return &etcd.Etcd{}, false, c.ArgErr() + return &Etcd{}, false, c.ArgErr() } etc.PathPrefix = c.Val() case "endpoint": args := c.RemainingArgs() if len(args) == 0 { - return &etcd.Etcd{}, false, c.ArgErr() + return &Etcd{}, false, c.ArgErr() } endpoints = args case "upstream": args := c.RemainingArgs() if len(args) == 0 { - return &etcd.Etcd{}, false, c.ArgErr() + return &Etcd{}, false, c.ArgErr() } for i := 0; i < len(args); i++ { h, p, e := net.SplitHostPort(args[i]) @@ -97,7 +105,7 @@ func etcdParse(c *Controller) (*etcd.Etcd, bool, error) { case "tls": // cert key cacertfile args := c.RemainingArgs() if len(args) != 3 { - return &etcd.Etcd{}, false, c.ArgErr() + return &Etcd{}, false, c.ArgErr() } tlsCertFile, tlsKeyFile, tlsCAcertFile = args[0], args[1], args[2] } @@ -109,19 +117,19 @@ func etcdParse(c *Controller) (*etcd.Etcd, bool, error) { etc.Debug = true case "path": if !c.NextArg() { - return &etcd.Etcd{}, false, c.ArgErr() + return &Etcd{}, false, c.ArgErr() } etc.PathPrefix = c.Val() case "endpoint": args := c.RemainingArgs() if len(args) == 0 { - return &etcd.Etcd{}, false, c.ArgErr() + return &Etcd{}, false, c.ArgErr() } endpoints = args case "upstream": args := c.RemainingArgs() if len(args) == 0 { - return &etcd.Etcd{}, false, c.ArgErr() + return &Etcd{}, false, c.ArgErr() } for i := 0; i < len(args); i++ { h, p, e := net.SplitHostPort(args[i]) @@ -133,7 +141,7 @@ func etcdParse(c *Controller) (*etcd.Etcd, bool, error) { case "tls": // cert key cacertfile args := c.RemainingArgs() if len(args) != 3 { - return &etcd.Etcd{}, false, c.ArgErr() + return &Etcd{}, false, c.ArgErr() } tlsCertFile, tlsKeyFile, tlsCAcertFile = args[0], args[1], args[2] } @@ -141,13 +149,13 @@ func etcdParse(c *Controller) (*etcd.Etcd, bool, error) { } client, err := newEtcdClient(endpoints, tlsCertFile, tlsKeyFile, tlsCAcertFile) if err != nil { - return &etcd.Etcd{}, false, err + return &Etcd{}, false, err } etc.Client = client return &etc, stubzones, nil } } - return &etcd.Etcd{}, false, nil + return &Etcd{}, false, nil } func newEtcdClient(endpoints []string, tlsCert, tlsKey, tlsCACert string) (etcdc.KeysAPI, error) { @@ -195,3 +203,5 @@ func newHTTPSTransport(tlsCertFile, tlsKeyFile, tlsCACertFile string) etcdc.Canc return tr } + +const defaultEndpoint = "http://localhost:2379" diff --git a/middleware/etcd/setup_test.go b/middleware/etcd/setup_test.go index 799b4a1bb..b522345d2 100644 --- a/middleware/etcd/setup_test.go +++ b/middleware/etcd/setup_test.go @@ -19,20 +19,19 @@ import ( "golang.org/x/net/context" ) -var ( - etc *Etcd - client etcdc.KeysAPI - ctxt context.Context -) - func init() { ctxt, _ = context.WithTimeout(context.Background(), etcdTimeout) +} + +// etc *Etcd +func newEtcdMiddleware() *Etcd { + ctxt, _ = context.WithTimeout(context.Background(), etcdTimeout) etcdCfg := etcdc.Config{ Endpoints: []string{"http://localhost:2379"}, } cli, _ := etcdc.New(etcdCfg) - etc = &Etcd{ + return &Etcd{ Proxy: proxy.New([]string{"8.8.8.8:53"}), PathPrefix: "skydns", Ctx: context.Background(), @@ -57,10 +56,12 @@ func delete(t *testing.T, e *Etcd, k string) { } func TestLookup(t *testing.T) { + etc := newEtcdMiddleware() for _, serv := range services { set(t, etc, serv.Key, 0, serv) defer delete(t, etc, serv.Key) } + for _, tc := range dnsTestCases { m := tc.Msg() @@ -91,3 +92,5 @@ func TestLookup(t *testing.T) { } } } + +var ctxt context.Context diff --git a/middleware/etcd/stub_test.go b/middleware/etcd/stub_test.go index 1dc0901c0..b5a101dad 100644 --- a/middleware/etcd/stub_test.go +++ b/middleware/etcd/stub_test.go @@ -41,13 +41,14 @@ func TestStubLookup(t *testing.T) { exampleNetStub := &msg.Service{Host: host, Port: port, Key: "a.example.net.stub.dns.skydns.test."} servicesStub = append(servicesStub, exampleNetStub) + etc := newEtcdMiddleware() + for _, serv := range servicesStub { set(t, etc, serv.Key, 0, serv) defer delete(t, etc, serv.Key) } etc.updateStubZones() - defer func() { etc.Stubmap = nil }() for _, tc := range dnsTestCasesStub { m := tc.Msg() diff --git a/core/setup/file.go b/middleware/file/setup.go similarity index 66% rename from core/setup/file.go rename to middleware/file/setup.go index a0b90c3ca..8b44650ee 100644 --- a/core/setup/file.go +++ b/middleware/file/setup.go @@ -1,24 +1,32 @@ -package setup +package file import ( "fmt" "net" "os" + "github.com/miekg/coredns/core/dnsserver" "github.com/miekg/coredns/middleware" - "github.com/miekg/coredns/middleware/file" + + "github.com/mholt/caddy" ) -// File sets up the file middleware. -func File(c *Controller) (middleware.Middleware, error) { +func init() { + caddy.RegisterPlugin("file", caddy.Plugin{ + ServerType: "dns", + Action: setup, + }) +} + +func setup(c *caddy.Controller) error { zones, err := fileParse(c) if err != nil { - return nil, err + return err } // Add startup functions to notify the master(s). for _, n := range zones.Names { - c.Startup = append(c.Startup, func() error { + c.OnStartup(func() error { zones.Z[n].StartupOnce.Do(func() { if len(zones.Z[n].TransferTo) > 0 { zones.Z[n].Notify() @@ -29,24 +37,28 @@ func File(c *Controller) (middleware.Middleware, error) { }) } - return func(next middleware.Handler) middleware.Handler { - return file.File{Next: next, Zones: zones} - }, nil + dnsserver.GetConfig(c).AddMiddleware(func(next dnsserver.Handler) dnsserver.Handler { + return File{Next: next, Zones: zones} + }) + return nil } -func fileParse(c *Controller) (file.Zones, error) { - z := make(map[string]*file.Zone) +func fileParse(c *caddy.Controller) (Zones, error) { + z := make(map[string]*Zone) names := []string{} + origins := []string{} + for c.Next() { if c.Val() == "file" { // file db.file [zones...] if !c.NextArg() { - return file.Zones{}, c.ArgErr() + return Zones{}, c.ArgErr() } fileName := c.Val() - origins := c.ServerBlockHosts + origins = make([]string, len(c.ServerBlockKeys)) + copy(origins, c.ServerBlockKeys) args := c.RemainingArgs() if len(args) > 0 { origins = args @@ -55,25 +67,25 @@ func fileParse(c *Controller) (file.Zones, error) { reader, err := os.Open(fileName) if err != nil { // bail out - return file.Zones{}, err + return Zones{}, err } for i, _ := range origins { origins[i] = middleware.Host(origins[i]).Normalize() - zone, err := file.Parse(reader, origins[i], fileName) + zone, err := Parse(reader, origins[i], fileName) if err == nil { z[origins[i]] = zone } else { - return file.Zones{}, err + return Zones{}, err } names = append(names, origins[i]) } noReload := false for c.NextBlock() { - t, _, e := transferParse(c) + t, _, e := TransferParse(c) if e != nil { - return file.Zones{}, e + return Zones{}, e } switch c.Val() { case "no_reload": @@ -89,11 +101,12 @@ func fileParse(c *Controller) (file.Zones, error) { } } } - return file.Zones{Z: z, Names: names}, nil + return Zones{Z: z, Names: names}, nil } -// transferParse parses transfer statements: 'transfer to [address...]'. -func transferParse(c *Controller) (tos, froms []string, err error) { +// TransferParse parses transfer statements: 'transfer to [address...]'. +// Exported so secondary can use this as well. +func TransferParse(c *caddy.Controller) (tos, froms []string, err error) { what := c.Val() if !c.NextArg() { return nil, nil, c.ArgErr() diff --git a/middleware/health/health.go b/middleware/health/health.go index 035c9ca7a..1d47e409e 100644 --- a/middleware/health/health.go +++ b/middleware/health/health.go @@ -12,15 +12,16 @@ var once sync.Once type Health struct { Addr string - ln net.Listener - mux *http.ServeMux + + ln net.Listener + mux *http.ServeMux } func health(w http.ResponseWriter, r *http.Request) { io.WriteString(w, ok) } -func (h *Health) Start() error { +func (h *Health) Startup() error { if h.Addr == "" { h.Addr = defAddr } diff --git a/middleware/health/setup.go b/middleware/health/setup.go new file mode 100644 index 000000000..cf7667d17 --- /dev/null +++ b/middleware/health/setup.go @@ -0,0 +1,42 @@ +package health + +import "github.com/mholt/caddy" + +func init() { + caddy.RegisterPlugin("health", caddy.Plugin{ + ServerType: "dns", + Action: setup, + }) +} + +func setup(c *caddy.Controller) error { + addr, err := healthParse(c) + if err != nil { + return err + } + + health := &Health{Addr: addr} + c.OnStartup(health.Startup) + c.OnShutdown(health.Shutdown) + + // Don't do AddMiddleware, as health is not *really* a middleware just a separate + // webserver running. + + return nil +} + +func healthParse(c *caddy.Controller) (string, error) { + addr := "" + for c.Next() { + args := c.RemainingArgs() + + switch len(args) { + case 0: + case 1: + addr = args[0] + default: + return "", c.ArgErr() + } + } + return addr, nil +} diff --git a/middleware/kubernetes/controller.go b/middleware/kubernetes/controller.go index 3fbea313e..5de16d61c 100644 --- a/middleware/kubernetes/controller.go +++ b/middleware/kubernetes/controller.go @@ -2,7 +2,6 @@ package kubernetes import ( "fmt" - "log" "sync" "time" @@ -12,7 +11,7 @@ import ( "k8s.io/kubernetes/pkg/client/cache" client "k8s.io/kubernetes/pkg/client/unversioned" "k8s.io/kubernetes/pkg/controller/framework" - "k8s.io/kubernetes/pkg/labels" + "k8s.io/kubernetes/pkg/labels" "k8s.io/kubernetes/pkg/runtime" "k8s.io/kubernetes/pkg/watch" ) @@ -24,7 +23,7 @@ var ( type dnsController struct { client *client.Client - selector *labels.Selector + selector *labels.Selector endpController *framework.Controller svcController *framework.Controller @@ -45,9 +44,9 @@ type dnsController struct { // newDNSController creates a controller for coredns func newdnsController(kubeClient *client.Client, resyncPeriod time.Duration, lselector *labels.Selector) *dnsController { dns := dnsController{ - client: kubeClient, - selector: lselector, - stopCh: make(chan struct{}), + client: kubeClient, + selector: lselector, + stopCh: make(chan struct{}), } dns.endpLister.Store, dns.endpController = framework.NewInformer( @@ -76,54 +75,54 @@ func newdnsController(kubeClient *client.Client, resyncPeriod time.Duration, lse func serviceListFunc(c *client.Client, ns string, s *labels.Selector) func(api.ListOptions) (runtime.Object, error) { return func(opts api.ListOptions) (runtime.Object, error) { - if s != nil { - opts.LabelSelector = *s - } + if s != nil { + opts.LabelSelector = *s + } return c.Services(ns).List(opts) } } func serviceWatchFunc(c *client.Client, ns string, s *labels.Selector) func(options api.ListOptions) (watch.Interface, error) { return func(options api.ListOptions) (watch.Interface, error) { - if s != nil { - options.LabelSelector = *s - } + if s != nil { + options.LabelSelector = *s + } return c.Services(ns).Watch(options) } } func endpointsListFunc(c *client.Client, ns string, s *labels.Selector) func(api.ListOptions) (runtime.Object, error) { return func(opts api.ListOptions) (runtime.Object, error) { - if s != nil { - opts.LabelSelector = *s - } + if s != nil { + opts.LabelSelector = *s + } return c.Endpoints(ns).List(opts) } } func endpointsWatchFunc(c *client.Client, ns string, s *labels.Selector) func(options api.ListOptions) (watch.Interface, error) { return func(options api.ListOptions) (watch.Interface, error) { - if s != nil { - options.LabelSelector = *s - } + if s != nil { + options.LabelSelector = *s + } return c.Endpoints(ns).Watch(options) } } func namespaceListFunc(c *client.Client, s *labels.Selector) func(api.ListOptions) (runtime.Object, error) { return func(opts api.ListOptions) (runtime.Object, error) { - if s != nil { - opts.LabelSelector = *s - } + if s != nil { + opts.LabelSelector = *s + } return c.Namespaces().List(opts) } } func namespaceWatchFunc(c *client.Client, s *labels.Selector) func(options api.ListOptions) (watch.Interface, error) { return func(options api.ListOptions) (watch.Interface, error) { - if s != nil { - options.LabelSelector = *s - } + if s != nil { + options.LabelSelector = *s + } return c.Namespaces().Watch(options) } } @@ -140,7 +139,6 @@ func (dns *dnsController) Stop() error { // Only try draining the workqueue if we haven't already. if !dns.shutdown { close(dns.stopCh) - log.Println("shutting down controller queues") dns.shutdown = true return nil @@ -151,14 +149,10 @@ func (dns *dnsController) Stop() error { // Run starts the controller. func (dns *dnsController) Run() { - log.Println("[debug] Starting k8s notification controllers") - go dns.endpController.Run(dns.stopCh) go dns.svcController.Run(dns.stopCh) go dns.nsController.Run(dns.stopCh) - <-dns.stopCh - log.Println("[debug] shutting down coredns controller") } func (dns *dnsController) GetNamespaceList() *api.NamespaceList { @@ -203,12 +197,12 @@ func (dns *dnsController) GetServiceInNamespace(namespace string, servicename st svcObj, svcExists, err := dns.svcLister.Store.GetByKey(svcKey) if err != nil { - log.Printf("error getting service %v from the cache: %v\n", svcKey, err) + // TODO(...): should return err here return nil } if !svcExists { - log.Printf("service %v does not exists\n", svcKey) + // TODO(...): should return err here return nil } diff --git a/middleware/kubernetes/handler.go b/middleware/kubernetes/handler.go index 05dfba934..1986820d5 100644 --- a/middleware/kubernetes/handler.go +++ b/middleware/kubernetes/handler.go @@ -2,7 +2,6 @@ package kubernetes import ( "fmt" - "log" "strings" "github.com/miekg/coredns/middleware" @@ -12,8 +11,6 @@ import ( ) func (k Kubernetes) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { - log.Printf("[debug] here entering ServeDNS: ctx:%v dnsmsg:%v\n", ctx, r) - state := middleware.State{W: w, Req: r} if state.QClass() != dns.ClassINET { return dns.RcodeServerFailure, fmt.Errorf("can only deal with ClassINET") diff --git a/middleware/kubernetes/kubernetes.go b/middleware/kubernetes/kubernetes.go index 5e2a1bf53..59a044140 100644 --- a/middleware/kubernetes/kubernetes.go +++ b/middleware/kubernetes/kubernetes.go @@ -16,10 +16,10 @@ import ( "github.com/miekg/dns" "k8s.io/kubernetes/pkg/api" unversionedapi "k8s.io/kubernetes/pkg/api/unversioned" - "k8s.io/kubernetes/pkg/labels" unversionedclient "k8s.io/kubernetes/pkg/client/unversioned" "k8s.io/kubernetes/pkg/client/unversioned/clientcmd" clientcmdapi "k8s.io/kubernetes/pkg/client/unversioned/clientcmd/api" + "k8s.io/kubernetes/pkg/labels" ) type Kubernetes struct { @@ -32,10 +32,10 @@ type Kubernetes struct { NameTemplate *nametemplate.NameTemplate Namespaces []string LabelSelector *unversionedapi.LabelSelector - Selector *labels.Selector + Selector *labels.Selector } -func (g *Kubernetes) StartKubeCache() error { +func (g *Kubernetes) InitKubeCache() error { // For a custom api server or running outside a k8s cluster // set URL in env.KUBERNETES_MASTER or set endpoint in Corefile loadingRules := clientcmd.NewDefaultClientConfigLoadingRules() @@ -46,7 +46,6 @@ func (g *Kubernetes) StartKubeCache() error { clientConfig := clientcmd.NewNonInteractiveDeferredLoadingClientConfig(loadingRules, overrides) config, err := clientConfig.ClientConfig() if err != nil { - log.Printf("[debug] error connecting to the client: %v", err) return err } kubeClient, err := unversionedclient.New(config) @@ -58,20 +57,17 @@ func (g *Kubernetes) StartKubeCache() error { if g.LabelSelector == nil { log.Printf("[INFO] Kubernetes middleware configured without a label selector. No label-based filtering will be performed.") } else { - var selector labels.Selector + var selector labels.Selector selector, err = unversionedapi.LabelSelectorAsSelector(g.LabelSelector) - g.Selector = &selector - if err != nil { - log.Printf("[ERROR] Unable to create Selector for LabelSelector '%s'.Error was: %s", g.LabelSelector, err) - return err - } + g.Selector = &selector + if err != nil { + log.Printf("[ERROR] Unable to create Selector for LabelSelector '%s'.Error was: %s", g.LabelSelector, err) + return err + } log.Printf("[INFO] Kubernetes middleware configured with the label selector '%s'. Only kubernetes objects matching this label selector will be exposed.", unversionedapi.FormatLabelSelector(g.LabelSelector)) } - log.Printf("[debug] Starting kubernetes middleware with k8s API resync period: %s", g.ResyncPeriod) g.APIConn = newdnsController(kubeClient, g.ResyncPeriod, g.Selector) - go g.APIConn.Run() - return err } @@ -115,7 +111,6 @@ func (g *Kubernetes) Records(name string, exact bool) ([]msg.Service, error) { typeName string ) - log.Printf("[debug] enter Records('%v', '%v')\n", name, exact) zone, serviceSegments := g.getZoneForName(name) // TODO: Implementation above globbed together segments for the serviceName if @@ -137,30 +132,18 @@ func (g *Kubernetes) Records(name string, exact bool) ([]msg.Service, error) { serviceName = util.WildcardStar } - log.Printf("[debug] published namespaces: %v\n", g.Namespaces) - - log.Printf("[debug] exact: %v\n", exact) - log.Printf("[debug] zone: %v\n", zone) - log.Printf("[debug] servicename: %v\n", serviceName) - log.Printf("[debug] namespace: %v\n", namespace) - log.Printf("[debug] typeName: %v\n", typeName) - log.Printf("[debug] APIconn: %v\n", g.APIConn) - nsWildcard := util.SymbolContainsWildcard(namespace) serviceWildcard := util.SymbolContainsWildcard(serviceName) // Abort if the namespace does not contain a wildcard, and namespace is not published per CoreFile // Case where namespace contains a wildcard is handled in Get(...) method. if (!nsWildcard) && (len(g.Namespaces) > 0) && (!util.StringInSlice(namespace, g.Namespaces)) { - log.Printf("[debug] Namespace '%v' is not published by Corefile\n", namespace) return nil, nil } - log.Printf("before g.Get(namespace, nsWildcard, serviceName, serviceWildcard): %v %v %v %v", namespace, nsWildcard, serviceName, serviceWildcard) + log.Printf("[debug] before g.Get(namespace, nsWildcard, serviceName, serviceWildcard): %v %v %v %v", namespace, nsWildcard, serviceName, serviceWildcard) k8sItems, err := g.Get(namespace, nsWildcard, serviceName, serviceWildcard) - log.Printf("[debug] k8s items: %v\n", k8sItems) if err != nil { - log.Printf("[ERROR] Got error while looking up ServiceItems. Error is: %v\n", err) return nil, err } if k8sItems == nil { @@ -178,7 +161,6 @@ func (g *Kubernetes) getRecordsForServiceItems(serviceItems []api.Service, value for _, item := range serviceItems { clusterIP := item.Spec.ClusterIP - log.Printf("[debug] clusterIP: %v\n", clusterIP) // Create records by constructing record name from template... //values.Namespace = item.Metadata.Namespace @@ -188,13 +170,11 @@ func (g *Kubernetes) getRecordsForServiceItems(serviceItems []api.Service, value // Create records for each exposed port... for _, p := range item.Spec.Ports { - log.Printf("[debug] port: %v\n", p.Port) s := msg.Service{Host: clusterIP, Port: int(p.Port)} records = append(records, s) } } - log.Printf("[debug] records from getRecordsForServiceItems(): %v\n", records) return records } @@ -202,13 +182,6 @@ func (g *Kubernetes) getRecordsForServiceItems(serviceItems []api.Service, value func (g *Kubernetes) Get(namespace string, nsWildcard bool, servicename string, serviceWildcard bool) ([]api.Service, error) { serviceList := g.APIConn.GetServiceList() - /* TODO: Remove? - if err != nil { - log.Printf("[ERROR] Getting service list produced error: %v", err) - return nil, err - } - */ - var resultItems []api.Service for _, item := range serviceList.Items { @@ -216,7 +189,6 @@ func (g *Kubernetes) Get(namespace string, nsWildcard bool, servicename string, // If namespace has a wildcard, filter results against Corefile namespace list. // (Namespaces without a wildcard were filtered before the call to this function.) if nsWildcard && (len(g.Namespaces) > 0) && (!util.StringInSlice(item.Namespace, g.Namespaces)) { - log.Printf("[debug] Namespace '%v' is not published by Corefile\n", item.Namespace) continue } resultItems = append(resultItems, item) diff --git a/middleware/kubernetes/nametemplate/nametemplate.go b/middleware/kubernetes/nametemplate/nametemplate.go index 5a34ae4ad..3e1ac4bb3 100644 --- a/middleware/kubernetes/nametemplate/nametemplate.go +++ b/middleware/kubernetes/nametemplate/nametemplate.go @@ -2,7 +2,6 @@ package nametemplate import ( "errors" - "log" "strings" "github.com/miekg/coredns/middleware/kubernetes/util" @@ -87,17 +86,13 @@ func (t *NameTemplate) SetTemplate(s string) error { if !elementPositionSet { if strings.Contains(v, "{") { err = errors.New("Record name template contains the unknown symbol '" + v + "'") - log.Printf("[debug] %v\n", err) return err - } else { - log.Printf("[debug] Template string has static element '%v'\n", v) } } } if err == nil && !t.IsValid() { err = errors.New("Record name template does not pass NameTemplate validation") - log.Printf("[debug] %v\n", err) return err } diff --git a/core/setup/kubernetes.go b/middleware/kubernetes/setup.go similarity index 57% rename from core/setup/kubernetes.go rename to middleware/kubernetes/setup.go index 7439a9f1b..fc3c036b8 100644 --- a/core/setup/kubernetes.go +++ b/middleware/kubernetes/setup.go @@ -1,70 +1,78 @@ -package setup +package kubernetes import ( "errors" "fmt" - "log" "strings" "time" + "github.com/miekg/coredns/core/dnsserver" "github.com/miekg/coredns/middleware" - "github.com/miekg/coredns/middleware/kubernetes" "github.com/miekg/coredns/middleware/kubernetes/nametemplate" + + "github.com/mholt/caddy" unversionedapi "k8s.io/kubernetes/pkg/api/unversioned" ) -const ( - defaultNameTemplate = "{service}.{namespace}.{zone}" - defaultResyncPeriod = 5 * time.Minute -) - -// Kubernetes sets up the kubernetes middleware. -func Kubernetes(c *Controller) (middleware.Middleware, error) { - kubernetes, err := kubernetesParse(c) - if err != nil { - return nil, err - } - - err = kubernetes.StartKubeCache() - if err != nil { - return nil, err - } - - return func(next middleware.Handler) middleware.Handler { - kubernetes.Next = next - return kubernetes - }, nil +func init() { + caddy.RegisterPlugin("kubernetes", caddy.Plugin{ + ServerType: "dns", + Action: setup, + }) } -func kubernetesParse(c *Controller) (kubernetes.Kubernetes, error) { +func setup(c *caddy.Controller) error { + kubernetes, err := kubernetesParse(c) + if err != nil { + return err + } + + err = kubernetes.InitKubeCache() + if err != nil { + return err + } + + // Register KubeCache start and stop functions with Caddy + c.OnStartup(func() error { + go kubernetes.APIConn.Run() + return nil + }) + + c.OnShutdown(func() error { + return kubernetes.APIConn.Stop() + }) + + dnsserver.GetConfig(c).AddMiddleware(func(next dnsserver.Handler) dnsserver.Handler { + kubernetes.Next = next + return kubernetes + }) + + return nil +} + +func kubernetesParse(c *caddy.Controller) (Kubernetes, error) { var err error template := defaultNameTemplate - k8s := kubernetes.Kubernetes{ - ResyncPeriod: defaultResyncPeriod, - } + k8s := Kubernetes{ResyncPeriod: defaultResyncPeriod} k8s.NameTemplate = new(nametemplate.NameTemplate) k8s.NameTemplate.SetTemplate(template) - // TODO: expose resync period in Corefile - for c.Next() { if c.Val() == "kubernetes" { zones := c.RemainingArgs() if len(zones) == 0 { - k8s.Zones = c.ServerBlockHosts - log.Printf("[debug] Zones(from ServerBlockHosts): %v", zones) - } else { - // Normalize requested zones - k8s.Zones = kubernetes.NormalizeZoneList(zones) + k8s.Zones = make([]string, len(c.ServerBlockKeys)) + copy(k8s.Zones, c.ServerBlockKeys) } + k8s.Zones = NormalizeZoneList(zones) middleware.Zones(k8s.Zones).FullyQualify() + if k8s.Zones == nil || len(k8s.Zones) < 1 { err = errors.New("Zone name must be provided for kubernetes middleware.") - log.Printf("[debug] %v\n", err) - return kubernetes.Kubernetes{}, err + return Kubernetes{}, err } for c.NextBlock() { @@ -75,27 +83,24 @@ func kubernetesParse(c *Controller) (kubernetes.Kubernetes, error) { template := strings.Join(args, "") err = k8s.NameTemplate.SetTemplate(template) if err != nil { - return kubernetes.Kubernetes{}, err + return Kubernetes{}, err } } else { - log.Printf("[debug] 'template' keyword provided without any template value.") - return kubernetes.Kubernetes{}, c.ArgErr() + return Kubernetes{}, c.ArgErr() } case "namespaces": args := c.RemainingArgs() if len(args) != 0 { k8s.Namespaces = append(k8s.Namespaces, args...) } else { - log.Printf("[debug] 'namespaces' keyword provided without any namespace values.") - return kubernetes.Kubernetes{}, c.ArgErr() + return Kubernetes{}, c.ArgErr() } case "endpoint": args := c.RemainingArgs() if len(args) != 0 { k8s.APIEndpoint = args[0] } else { - log.Printf("[debug] 'endpoint' keyword provided without any endpoint url value.") - return kubernetes.Kubernetes{}, c.ArgErr() + return Kubernetes{}, c.ArgErr() } case "resyncperiod": args := c.RemainingArgs() @@ -103,12 +108,10 @@ func kubernetesParse(c *Controller) (kubernetes.Kubernetes, error) { k8s.ResyncPeriod, err = time.ParseDuration(args[0]) if err != nil { err = errors.New(fmt.Sprintf("Unable to parse resync duration value. Value provided was '%v'. Example valid values: '15s', '5m', '1h'. Error was: %v", args[0], err)) - log.Printf("[ERROR] %v", err) - return kubernetes.Kubernetes{}, err + return Kubernetes{}, err } } else { - log.Printf("[debug] 'resyncperiod' keyword provided without any duration value.") - return kubernetes.Kubernetes{}, c.ArgErr() + return Kubernetes{}, c.ArgErr() } case "labels": args := c.RemainingArgs() @@ -117,12 +120,10 @@ func kubernetesParse(c *Controller) (kubernetes.Kubernetes, error) { k8s.LabelSelector, err = unversionedapi.ParseToLabelSelector(labelSelectorString) if err != nil { err = errors.New(fmt.Sprintf("Unable to parse label selector. Value provided was '%v'. Error was: %v", labelSelectorString, err)) - log.Printf("[ERROR] %v", err) - return kubernetes.Kubernetes{}, err + return Kubernetes{}, err } } else { - log.Printf("[debug] 'labels' keyword provided without any selector value.") - return kubernetes.Kubernetes{}, c.ArgErr() + return Kubernetes{}, c.ArgErr() } } } @@ -130,6 +131,10 @@ func kubernetesParse(c *Controller) (kubernetes.Kubernetes, error) { } } err = errors.New("Kubernetes setup called without keyword 'kubernetes' in Corefile") - log.Printf("[ERROR] %v\n", err) - return kubernetes.Kubernetes{}, err + return Kubernetes{}, err } + +const ( + defaultNameTemplate = "{service}.{namespace}.{zone}" + defaultResyncPeriod = 5 * time.Minute +) diff --git a/core/setup/kubernetes_test.go b/middleware/kubernetes/setup_test.go similarity index 98% rename from core/setup/kubernetes_test.go rename to middleware/kubernetes/setup_test.go index cf6ac9abc..6e0918a3b 100644 --- a/core/setup/kubernetes_test.go +++ b/middleware/kubernetes/setup_test.go @@ -1,10 +1,11 @@ -package setup +package kubernetes import ( "strings" "testing" "time" + "github.com/mholt/caddy" unversionedapi "k8s.io/kubernetes/pkg/api/unversioned" ) @@ -320,7 +321,7 @@ func TestKubernetesParse(t *testing.T) { t.Logf("Parser test cases count: %v", len(tests)) for i, test := range tests { - c := NewTestController(test.input) + c := caddy.NewTestController("dns", test.input) k8sController, err := kubernetesParse(c) t.Logf("setup test: %2v -- %v\n", i, test.description) //t.Logf("controller: %v\n", k8sController) diff --git a/middleware/loadbalance/setup.go b/middleware/loadbalance/setup.go new file mode 100644 index 000000000..ef3f35f03 --- /dev/null +++ b/middleware/loadbalance/setup.go @@ -0,0 +1,25 @@ +package loadbalance + +import ( + "github.com/mholt/caddy" + "github.com/miekg/coredns/core/dnsserver" +) + +func init() { + caddy.RegisterPlugin("loadbalance", caddy.Plugin{ + ServerType: "dns", + Action: setup, + }) +} + +func setup(c *caddy.Controller) error { + for c.Next() { + // TODO(miek): block and option parsing + } + + dnsserver.GetConfig(c).AddMiddleware(func(next dnsserver.Handler) dnsserver.Handler { + return RoundRobin{Next: next} + }) + + return nil +} diff --git a/core/setup/log.go b/middleware/log/setup.go similarity index 70% rename from core/setup/log.go rename to middleware/log/setup.go index 1deca2565..a80e69ac3 100644 --- a/core/setup/log.go +++ b/middleware/log/setup.go @@ -1,26 +1,34 @@ -package setup +package log import ( "io" "log" "os" - "github.com/hashicorp/go-syslog" + "github.com/miekg/coredns/core/dnsserver" "github.com/miekg/coredns/middleware" - corednslog "github.com/miekg/coredns/middleware/log" "github.com/miekg/coredns/server" + + "github.com/hashicorp/go-syslog" + "github.com/mholt/caddy" "github.com/miekg/dns" ) -// Log sets up the logging middleware. -func Log(c *Controller) (middleware.Middleware, error) { +func init() { + caddy.RegisterPlugin("log", caddy.Plugin{ + ServerType: "dns", + Action: setup, + }) +} + +func setup(c *caddy.Controller) error { rules, err := logParse(c) if err != nil { - return nil, err + return err } // Open the log files for writing when the server starts - c.Startup = append(c.Startup, func() error { + c.OnStartup(func() error { for i := 0; i < len(rules); i++ { var err error var writer io.Writer @@ -55,13 +63,15 @@ func Log(c *Controller) (middleware.Middleware, error) { return nil }) - return func(next middleware.Handler) middleware.Handler { - return corednslog.Logger{Next: next, Rules: rules, ErrorFunc: server.DefaultErrorFunc} - }, nil + dnsserver.GetConfig(c).AddMiddleware(func(next dnsserver.Handler) dnsserver.Handler { + return Logger{Next: next, Rules: rules, ErrorFunc: server.DefaultErrorFunc} + }) + + return nil } -func logParse(c *Controller) ([]corednslog.Rule, error) { - var rules []corednslog.Rule +func logParse(c *caddy.Controller) ([]Rule, error) { + var rules []Rule for c.Next() { args := c.RemainingArgs() @@ -72,7 +82,7 @@ func logParse(c *Controller) ([]corednslog.Rule, error) { if c.NextArg() { if c.Val() == "{" { var err error - logRoller, err = parseRoller(c) + logRoller, err = middleware.ParseRoller(c) if err != nil { return nil, err } @@ -88,37 +98,37 @@ func logParse(c *Controller) ([]corednslog.Rule, error) { } if len(args) == 0 { // Nothing specified; use defaults - rules = append(rules, corednslog.Rule{ + rules = append(rules, Rule{ NameScope: ".", - OutputFile: corednslog.DefaultLogFilename, - Format: corednslog.DefaultLogFormat, + OutputFile: DefaultLogFilename, + Format: DefaultLogFormat, Roller: logRoller, }) } else if len(args) == 1 { // Only an output file specified - rules = append(rules, corednslog.Rule{ + rules = append(rules, Rule{ NameScope: ".", OutputFile: args[0], - Format: corednslog.DefaultLogFormat, + Format: DefaultLogFormat, Roller: logRoller, }) } else { // Name scope, output file, and maybe a format specified - format := corednslog.DefaultLogFormat + format := DefaultLogFormat if len(args) > 2 { switch args[2] { case "{common}": - format = corednslog.CommonLogFormat + format = CommonLogFormat case "{combined}": - format = corednslog.CombinedLogFormat + format = CombinedLogFormat default: format = args[2] } } - rules = append(rules, corednslog.Rule{ + rules = append(rules, Rule{ NameScope: dns.Fqdn(args[0]), OutputFile: args[1], Format: format, diff --git a/core/setup/log_test.go b/middleware/log/setup_test.go similarity index 65% rename from core/setup/log_test.go rename to middleware/log/setup_test.go index ad9cb7c3a..0a3ee63fe 100644 --- a/core/setup/log_test.go +++ b/middleware/log/setup_test.go @@ -1,99 +1,61 @@ -package setup +package log import ( "testing" "github.com/miekg/coredns/middleware" - corednslog "github.com/miekg/coredns/middleware/log" + + "github.com/mholt/caddy" ) -func TestLog(t *testing.T) { - - c := NewTestController(`log`) - - mid, err := Log(c) - - if err != nil { - t.Errorf("Expected no errors, got: %v", err) - } - - if mid == nil { - t.Fatal("Expected middleware, was nil instead") - } - - handler := mid(EmptyNext) - myHandler, ok := handler.(corednslog.Logger) - - if !ok { - t.Fatalf("Expected handler to be type Logger, got: %#v", handler) - } - - if myHandler.Rules[0].NameScope != "." { - t.Errorf("Expected . as the default NameScope") - } - if myHandler.Rules[0].OutputFile != corednslog.DefaultLogFilename { - t.Errorf("Expected %s as the default OutputFile", corednslog.DefaultLogFilename) - } - if myHandler.Rules[0].Format != corednslog.DefaultLogFormat { - t.Errorf("Expected %s as the default Log Format", corednslog.DefaultLogFormat) - } - if myHandler.Rules[0].Roller != nil { - t.Errorf("Expected Roller to be nil, got: %v", *myHandler.Rules[0].Roller) - } - if !SameNext(myHandler.Next, EmptyNext) { - t.Error("'Next' field of handler was not set properly") - } - -} - func TestLogParse(t *testing.T) { tests := []struct { inputLogRules string shouldErr bool - expectedLogRules []corednslog.Rule + expectedLogRules []Rule }{ - {`log`, false, []corednslog.Rule{{ + {`log`, false, []Rule{{ NameScope: ".", - OutputFile: corednslog.DefaultLogFilename, - Format: corednslog.DefaultLogFormat, + OutputFile: DefaultLogFilename, + Format: DefaultLogFormat, }}}, - {`log log.txt`, false, []corednslog.Rule{{ + {`log log.txt`, false, []Rule{{ NameScope: ".", OutputFile: "log.txt", - Format: corednslog.DefaultLogFormat, + Format: DefaultLogFormat, }}}, - {`log example.org log.txt`, false, []corednslog.Rule{{ + {`log example.org log.txt`, false, []Rule{{ NameScope: "example.org.", OutputFile: "log.txt", - Format: corednslog.DefaultLogFormat, + Format: DefaultLogFormat, }}}, - {`log example.org. stdout`, false, []corednslog.Rule{{ + {`log example.org. stdout`, false, []Rule{{ NameScope: "example.org.", OutputFile: "stdout", - Format: corednslog.DefaultLogFormat, + Format: DefaultLogFormat, }}}, - {`log example.org log.txt {common}`, false, []corednslog.Rule{{ + {`log example.org log.txt {common}`, false, []Rule{{ NameScope: "example.org.", OutputFile: "log.txt", - Format: corednslog.CommonLogFormat, + Format: CommonLogFormat, }}}, - {`log example.org accesslog.txt {combined}`, false, []corednslog.Rule{{ + {`log example.org accesslog.txt {combined}`, false, []Rule{{ NameScope: "example.org.", OutputFile: "accesslog.txt", - Format: corednslog.CombinedLogFormat, + Format: CombinedLogFormat, }}}, {`log example.org. log.txt - log example.net accesslog.txt {combined}`, false, []corednslog.Rule{{ + log example.net accesslog.txt {combined}`, false, []Rule{{ NameScope: "example.org.", OutputFile: "log.txt", - Format: corednslog.DefaultLogFormat, + Format: DefaultLogFormat, }, { NameScope: "example.net.", OutputFile: "accesslog.txt", - Format: corednslog.CombinedLogFormat, + Format: CombinedLogFormat, }}}, {`log example.org stdout {host} - log example.org log.txt {when}`, false, []corednslog.Rule{{ + log example.org log.txt {when}`, false, []Rule{{ NameScope: "example.org.", OutputFile: "stdout", Format: "{host}", @@ -102,10 +64,10 @@ func TestLogParse(t *testing.T) { OutputFile: "log.txt", Format: "{when}", }}}, - {`log access.log { rotate { size 2 age 10 keep 3 } }`, false, []corednslog.Rule{{ + {`log access.log { rotate { size 2 age 10 keep 3 } }`, false, []Rule{{ NameScope: ".", OutputFile: "access.log", - Format: corednslog.DefaultLogFormat, + Format: DefaultLogFormat, Roller: &middleware.LogRoller{ MaxSize: 2, MaxAge: 10, @@ -115,7 +77,7 @@ func TestLogParse(t *testing.T) { }}}, } for i, test := range tests { - c := NewTestController(test.inputLogRules) + c := caddy.NewTestController("dns", test.inputLogRules) actualLogRules, err := logParse(c) if err == nil && test.shouldErr { diff --git a/middleware/metrics/metrics.go b/middleware/metrics/metrics.go index 1c7db29d2..3a72bc3bb 100644 --- a/middleware/metrics/metrics.go +++ b/middleware/metrics/metrics.go @@ -34,7 +34,7 @@ type Metrics struct { ZoneNames []string } -func (m *Metrics) Start() error { +func (m *Metrics) Startup() error { m.Once.Do(func() { define() diff --git a/middleware/metrics/setup.go b/middleware/metrics/setup.go new file mode 100644 index 000000000..f31cbd4d5 --- /dev/null +++ b/middleware/metrics/setup.go @@ -0,0 +1,84 @@ +package metrics + +import ( + "sync" + + "github.com/miekg/coredns/core/dnsserver" + "github.com/miekg/coredns/middleware" + + "github.com/mholt/caddy" +) + +func init() { + caddy.RegisterPlugin("prometheus", caddy.Plugin{ + ServerType: "dns", + Action: setup, + }) +} + +func setup(c *caddy.Controller) error { + m, err := prometheusParse(c) + if err != nil { + return err + } + + dnsserver.GetConfig(c).AddMiddleware(func(next dnsserver.Handler) dnsserver.Handler { + m.Next = next + return m + }) + + metricsOnce.Do(func() { + c.OnStartup(m.Startup) + c.OnShutdown(m.Shutdown) + }) + + return nil +} + +func prometheusParse(c *caddy.Controller) (Metrics, error) { + var ( + met Metrics + err error + ) + + for c.Next() { + if len(met.ZoneNames) > 0 { + return Metrics{}, c.Err("metrics: can only have one metrics module per server") + } + met.ZoneNames = make([]string, len(c.ServerBlockKeys)) + copy(met.ZoneNames, c.ServerBlockKeys) + for i, _ := range met.ZoneNames { + met.ZoneNames[i] = middleware.Host(met.ZoneNames[i]).Normalize() + } + args := c.RemainingArgs() + + switch len(args) { + case 0: + case 1: + met.Addr = args[0] + default: + return Metrics{}, c.ArgErr() + } + for c.NextBlock() { + switch c.Val() { + case "address": + args = c.RemainingArgs() + if len(args) != 1 { + return Metrics{}, c.ArgErr() + } + met.Addr = args[0] + default: + return Metrics{}, c.Errf("metrics: unknown item: %s", c.Val()) + } + + } + } + if met.Addr == "" { + met.Addr = addr + } + return met, err +} + +var metricsOnce sync.Once + +const addr = "localhost:9153" diff --git a/middleware/pprof/pprof.go b/middleware/pprof/pprof.go index 42102d198..f538b3091 100644 --- a/middleware/pprof/pprof.go +++ b/middleware/pprof/pprof.go @@ -12,7 +12,7 @@ type Handler struct { mux *http.ServeMux } -func (h *Handler) Start() error { +func (h *Handler) Startup() error { if ln, err := net.Listen("tcp", addr); err != nil { log.Printf("[ERROR] Failed to start pprof handler: %s", err) return err diff --git a/middleware/pprof/setup.go b/middleware/pprof/setup.go new file mode 100644 index 000000000..2b6701f35 --- /dev/null +++ b/middleware/pprof/setup.go @@ -0,0 +1,40 @@ +package pprof + +import ( + "sync" + + "github.com/mholt/caddy" +) + +func init() { + caddy.RegisterPlugin("pprof", caddy.Plugin{ + ServerType: "dns", + Action: setup, + }) +} + +func setup(c *caddy.Controller) error { + found := false + for c.Next() { + if found { + return c.Err("pprof can only be specified once") + } + if len(c.RemainingArgs()) != 0 { + return c.ArgErr() + } + if c.NextBlock() { + return c.ArgErr() + } + found = true + } + + handler := &Handler{} + pprofOnce.Do(func() { + c.OnStartup(handler.Startup) + c.OnShutdown(handler.Shutdown) + }) + + return nil +} + +var pprofOnce sync.Once diff --git a/core/setup/pprof_test.go b/middleware/pprof/setup_test.go similarity index 78% rename from core/setup/pprof_test.go rename to middleware/pprof/setup_test.go index ac9375af7..af46fd415 100644 --- a/core/setup/pprof_test.go +++ b/middleware/pprof/setup_test.go @@ -1,6 +1,10 @@ -package setup +package pprof -import "testing" +import ( + "testing" + + "github.com/mholt/caddy" +) func TestPProf(t *testing.T) { tests := []struct { @@ -17,8 +21,8 @@ func TestPProf(t *testing.T) { pprof`, true}, } for i, test := range tests { - c := NewTestController(test.input) - _, err := PProf(c) + c := caddy.NewTestController("dns", test.input) + err := setup(c) if test.shouldErr && err == nil { t.Errorf("Test %v: Expected error but found nil", i) } else if !test.shouldErr && err != nil { diff --git a/middleware/proxy/setup.go b/middleware/proxy/setup.go new file mode 100644 index 000000000..81dc7777c --- /dev/null +++ b/middleware/proxy/setup.go @@ -0,0 +1,26 @@ +package proxy + +import ( + "github.com/miekg/coredns/core/dnsserver" + + "github.com/mholt/caddy" +) + +func init() { + caddy.RegisterPlugin("proxy", caddy.Plugin{ + ServerType: "dns", + Action: setup, + }) +} + +func setup(c *caddy.Controller) error { + upstreams, err := NewStaticUpstreams(c.Dispenser) + if err != nil { + return err + } + dnsserver.GetConfig(c).AddMiddleware(func(next dnsserver.Handler) dnsserver.Handler { + return Proxy{Next: next, Client: Clients(), Upstreams: upstreams} + }) + + return nil +} diff --git a/middleware/proxy/upstream.go b/middleware/proxy/upstream.go index 76c27a45a..12ec00d76 100644 --- a/middleware/proxy/upstream.go +++ b/middleware/proxy/upstream.go @@ -11,8 +11,9 @@ import ( "sync/atomic" "time" - "github.com/miekg/coredns/core/parse" "github.com/miekg/coredns/middleware" + + "github.com/mholt/caddy/caddyfile" "github.com/miekg/dns" ) @@ -43,7 +44,7 @@ type Options struct { // NewStaticUpstreams parses the configuration input and sets up // static upstreams for the proxy middleware. -func NewStaticUpstreams(c parse.Dispenser) ([]Upstream, error) { +func NewStaticUpstreams(c caddyfile.Dispenser) ([]Upstream, error) { var upstreams []Upstream for c.Next() { upstream := &staticUpstream{ @@ -73,7 +74,7 @@ func NewStaticUpstreams(c parse.Dispenser) ([]Upstream, error) { } for c.NextBlock() { - if err := parseBlock(&c, upstream); err != nil { + if err := parseBlock(c, upstream); err != nil { return upstreams, err } } @@ -125,7 +126,7 @@ func (u *staticUpstream) Options() Options { return u.options } -func parseBlock(c *parse.Dispenser, u *staticUpstream) error { +func parseBlock(c caddyfile.Dispenser, u *staticUpstream) error { switch c.Val() { case "policy": if !c.NextArg() { diff --git a/core/setup/rewrite.go b/middleware/rewrite/setup.go similarity index 70% rename from core/setup/rewrite.go rename to middleware/rewrite/setup.go index 86bef2ca3..abfc0fbd6 100644 --- a/core/setup/rewrite.go +++ b/middleware/rewrite/setup.go @@ -1,34 +1,40 @@ -package setup +package rewrite import ( "strconv" "strings" - "github.com/miekg/coredns/middleware" - "github.com/miekg/coredns/middleware/rewrite" + "github.com/miekg/coredns/core/dnsserver" + + "github.com/mholt/caddy" ) -// Rewrite configures a new Rewrite middleware instance. -func Rewrite(c *Controller) (middleware.Middleware, error) { - rewrites, err := rewriteParse(c) - if err != nil { - return nil, err - } - - return func(next middleware.Handler) middleware.Handler { - return rewrite.Rewrite{ - Next: next, - Rules: rewrites, - } - }, nil +func init() { + caddy.RegisterPlugin("rewrite", caddy.Plugin{ + ServerType: "dns", + Action: setup, + }) } -func rewriteParse(c *Controller) ([]rewrite.Rule, error) { - var simpleRules []rewrite.Rule - var regexpRules []rewrite.Rule +func setup(c *caddy.Controller) error { + rewrites, err := rewriteParse(c) + if err != nil { + return err + } + + dnsserver.GetConfig(c).AddMiddleware(func(next dnsserver.Handler) dnsserver.Handler { + return Rewrite{Next: next, Rules: rewrites} + }) + + return nil +} + +func rewriteParse(c *caddy.Controller) ([]Rule, error) { + var simpleRules []Rule + var regexpRules []Rule for c.Next() { - var rule rewrite.Rule + var rule Rule var err error var base = "." var pattern, to string @@ -37,7 +43,7 @@ func rewriteParse(c *Controller) ([]rewrite.Rule, error) { args := c.RemainingArgs() - var ifs []rewrite.If + var ifs []If switch len(args) { case 1: @@ -68,7 +74,7 @@ func rewriteParse(c *Controller) ([]rewrite.Rule, error) { if len(args1) != 3 { return nil, c.ArgErr() } - ifCond, err := rewrite.NewIf(args1[0], args1[1], args1[2]) + ifCond, err := NewIf(args1[0], args1[1], args1[2]) if err != nil { return nil, err } @@ -92,14 +98,14 @@ func rewriteParse(c *Controller) ([]rewrite.Rule, error) { // TODO(miek): complex rules base, pattern, to, status, ext, ifs = base, pattern, to, status, ext, ifs err = err - // if rule, err = rewrite.NewComplexRule(base, pattern, to, status, ext, ifs); err != nil { + // if rule, err = NewComplexRule(base, pattern, to, status, ext, ifs); err != nil { // return nil, err // } regexpRules = append(regexpRules, rule) // the only unhandled case is 2 and above default: - rule = rewrite.NewSimpleRule(args[0], strings.Join(args[1:], " ")) + rule = NewSimpleRule(args[0], strings.Join(args[1:], " ")) simpleRules = append(simpleRules, rule) } } diff --git a/middleware/roller.go b/middleware/roller.go index 995cabf91..81ff71c44 100644 --- a/middleware/roller.go +++ b/middleware/roller.go @@ -2,10 +2,45 @@ package middleware import ( "io" + "strconv" + "github.com/mholt/caddy" "gopkg.in/natefinch/lumberjack.v2" ) +func ParseRoller(c *caddy.Controller) (*LogRoller, error) { + var size, age, keep int + // This is kind of a hack to support nested blocks: + // As we are already in a block: either log or errors, + // c.nesting > 0 but, as soon as c meets a }, it thinks + // the block is over and return false for c.NextBlock. + for c.NextBlock() { + what := c.Val() + if !c.NextArg() { + return nil, c.ArgErr() + } + value := c.Val() + var err error + switch what { + case "size": + size, err = strconv.Atoi(value) + case "age": + age, err = strconv.Atoi(value) + case "keep": + keep, err = strconv.Atoi(value) + } + if err != nil { + return nil, err + } + } + return &LogRoller{ + MaxSize: size, + MaxAge: age, + MaxBackups: keep, + LocalTime: true, + }, nil +} + // LogRoller implements a middleware that provides a rolling logger. type LogRoller struct { Filename string diff --git a/core/setup/secondary.go b/middleware/secondary/setup.go similarity index 64% rename from core/setup/secondary.go rename to middleware/secondary/setup.go index e1f54a651..5550cf11d 100644 --- a/core/setup/secondary.go +++ b/middleware/secondary/setup.go @@ -1,22 +1,30 @@ -package setup +package secondary import ( + "github.com/miekg/coredns/core/dnsserver" "github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware/file" - "github.com/miekg/coredns/middleware/secondary" + + "github.com/mholt/caddy" ) -// Secondary sets up the secondary middleware. -func Secondary(c *Controller) (middleware.Middleware, error) { +func init() { + caddy.RegisterPlugin("secondary", caddy.Plugin{ + ServerType: "dns", + Action: setup, + }) +} + +func setup(c *caddy.Controller) error { zones, err := secondaryParse(c) if err != nil { - return nil, err + return err } // Add startup functions to retrieve the zone and keep it up to date. for _, n := range zones.Names { if len(zones.Z[n].TransferFrom) > 0 { - c.Startup = append(c.Startup, func() error { + c.OnStartup(func() error { zones.Z[n].StartupOnce.Do(func() { zones.Z[n].TransferIn() go func() { @@ -28,19 +36,22 @@ func Secondary(c *Controller) (middleware.Middleware, error) { } } - return func(next middleware.Handler) middleware.Handler { - return secondary.Secondary{file.File{Next: next, Zones: zones}} - }, nil + dnsserver.GetConfig(c).AddMiddleware(func(next dnsserver.Handler) dnsserver.Handler { + return Secondary{file.File{Next: next, Zones: zones}} + }) + return nil } -func secondaryParse(c *Controller) (file.Zones, error) { +func secondaryParse(c *caddy.Controller) (file.Zones, error) { z := make(map[string]*file.Zone) names := []string{} + origins := []string{} for c.Next() { if c.Val() == "secondary" { // secondary [origin] - origins := c.ServerBlockHosts + origins = make([]string, len(c.ServerBlockKeys)) + copy(origins, c.ServerBlockKeys) args := c.RemainingArgs() if len(args) > 0 { origins = args @@ -52,7 +63,7 @@ func secondaryParse(c *Controller) (file.Zones, error) { } for c.NextBlock() { - t, f, e := transferParse(c) + t, f, e := file.TransferParse(c) if e != nil { return file.Zones{}, e } diff --git a/plugin_generate.go b/plugin_generate.go new file mode 100644 index 000000000..fe8fef75d --- /dev/null +++ b/plugin_generate.go @@ -0,0 +1,81 @@ +//+build ignore + +package main + +import ( + "bytes" + "errors" + "go/ast" + "go/parser" + "go/printer" + "go/token" + "io/ioutil" + "log" + "strconv" +) + +func AddImportToFile(file, imprt string) ([]byte, error) { + fset := token.NewFileSet() + f, err := parser.ParseFile(fset, file, nil, parser.ParseComments) + if err != nil { + return nil, err + } + + for _, s := range f.Imports { + iSpec := &ast.ImportSpec{Path: &ast.BasicLit{Value: s.Path.Value}} + if iSpec.Path.Value == strconv.Quote(imprt) { + return nil, errors.New("coredns import already found") + } + } + + for i := 0; i < len(f.Decls); i++ { + d := f.Decls[i] + + switch d.(type) { + case *ast.FuncDecl: + // No action + case *ast.GenDecl: + dd := d.(*ast.GenDecl) + + // IMPORT Declarations + if dd.Tok == token.IMPORT { + // Add the new import + iSpec := &ast.ImportSpec{Name: &ast.Ident{Name: "_"}, Path: &ast.BasicLit{Value: strconv.Quote(imprt)}} + dd.Specs = append(dd.Specs, iSpec) + break + } + } + } + + ast.SortImports(fset, f) + + out, err := GenerateFile(fset, f) + return out, err +} + +func GenerateFile(fset *token.FileSet, file *ast.File) ([]byte, error) { + var output []byte + buffer := bytes.NewBuffer(output) + if err := printer.Fprint(buffer, fset, file); err != nil { + return nil, err + } + + return buffer.Bytes(), nil +} + +const ( + coredns = "github.com/miekg/coredns/core" + // If everything is OK and we are sitting in CoreDNS' dir, this is where run.go should be. + caddyrun = "../../mholt/caddy/caddy/caddymain/run.go" +) + +func main() { + out, err := AddImportToFile(caddyrun, coredns) + if err != nil { + log.Printf("failed to add import: %s", err) + return + } + if err := ioutil.WriteFile(caddyrun, out, 0644); err != nil { + log.Fatalf("failed to write go file: %s", err) + } +} diff --git a/server/server.go b/server/server.go index 0ba6f7f01..b0f89468e 100644 --- a/server/server.go +++ b/server/server.go @@ -28,13 +28,10 @@ import ( // the same address and the listener may be stopped for // graceful termination (POSIX only). type Server struct { - Addr string // Address we listen on - mux *dns.ServeMux - server [2]*dns.Server // by convention 0 is tcp and 1 is udp - - tcp net.Listener - udp net.PacketConn - listenerMu sync.Mutex // protects listener and packetconn + Addr string // Address we listen on + mux *dns.ServeMux + server [2]*dns.Server // by convention 0 is tcp and 1 is udp + listenerMu sync.Mutex // protects listener and packetconn inside server tls bool // whether this server is serving all HTTPS hosts or not TLSConfig *tls.Config @@ -115,16 +112,6 @@ func New(addr string, configs []Config, gracefulTimeout time.Duration) (*Server, return s, nil } -// LocalAddr return the addresses where the server is bound to. The TCP listener -// address is the first returned, the UDP conn address the second. -func (s *Server) LocalAddr() (net.Addr, net.Addr) { - s.listenerMu.Lock() - tcp := s.tcp.Addr() - udp := s.udp.LocalAddr() - s.listenerMu.Unlock() - return tcp, udp -} - // Serve starts the server with an existing listener. It blocks until the server stops. func (s *Server) Serve(ln net.Listener, pc net.PacketConn) error { err := s.setup() @@ -134,9 +121,7 @@ func (s *Server) Serve(ln net.Listener, pc net.PacketConn) error { } s.listenerMu.Lock() s.server[0] = &dns.Server{Listener: ln, Net: "tcp", Handler: s.mux} - s.tcp = ln s.server[1] = &dns.Server{PacketConn: pc, Net: "udp", Handler: s.mux} - s.udp = pc s.listenerMu.Unlock() go func() { @@ -168,9 +153,7 @@ func (s *Server) ListenAndServe() error { s.listenerMu.Lock() s.server[0] = &dns.Server{Listener: l, Net: "tcp", Handler: s.mux} - s.tcp = l s.server[1] = &dns.Server{PacketConn: pc, Net: "udp", Handler: s.mux} - s.udp = pc s.listenerMu.Unlock() go func() { @@ -252,17 +235,17 @@ func (s *Server) Stop() (err error) { // Close the listener now; this stops the server without delay s.listenerMu.Lock() - if s.tcp != nil { - err = s.tcp.Close() - } - if s.udp != nil { - err = s.udp.Close() - } + defer s.listenerMu.Unlock() for _, s1 := range s.server { + if s1.Listener != nil { + err = s1.Listener.Close() + } + if s1.PacketConn != nil { + err = s1.PacketConn.Close() + } err = s1.Shutdown() } - s.listenerMu.Unlock() return } @@ -280,8 +263,8 @@ func (s *Server) WaitUntilStarted() { func (s *Server) ListenerFd() *os.File { s.listenerMu.Lock() defer s.listenerMu.Unlock() - if s.tcp != nil { - file, _ := s.tcp.(*net.TCPListener).File() + if s.server[0].Listener != nil { + file, _ := s.server[0].Listener.(*net.TCPListener).File() return file } return nil @@ -293,8 +276,8 @@ func (s *Server) ListenerFd() *os.File { func (s *Server) PacketConnFd() *os.File { s.listenerMu.Lock() defer s.listenerMu.Unlock() - if s.udp != nil { - file, _ := s.udp.(*net.UDPConn).File() + if s.server[1].PacketConn != nil { + file, _ := s.server[1].PacketConn.(*net.UDPConn).File() return file } return nil diff --git a/test/etcd_test.go b/test/etcd_test.go index ceec936e7..6beb4cff2 100644 --- a/test/etcd_test.go +++ b/test/etcd_test.go @@ -46,13 +46,18 @@ func TestEtcdStubAndProxyLookup(t *testing.T) { proxy . 8.8.8.8:53 }` - etc := etcdMiddleware() - ex, _, udp, err := Server(t, corefile) + ex, err := CoreDNSServer(corefile) if err != nil { - t.Fatalf("Could get server: %s", err) + t.Fatalf("could not get CoreDNS serving instance: %s", err) + } + + udp, _ := CoreDNSServerPorts(ex, 0) + if udp == "" { + t.Fatalf("could not get udp listening port") } defer ex.Stop() + etc := etcdMiddleware() log.SetOutput(ioutil.Discard) var ctx = context.TODO() diff --git a/test/fail_start_test.go b/test/fail_start_test.go deleted file mode 100644 index aa1b137af..000000000 --- a/test/fail_start_test.go +++ /dev/null @@ -1,21 +0,0 @@ -package test - -import ( - "testing" - - "github.com/miekg/coredns/core" -) - -// Bind to low port should fail. -func TestFailStartServer(t *testing.T) { - corefile := `.:53 { - chaos CoreDNS-001 miek@miek.nl -} -` - srv, _ := core.TestServer(t, corefile) - err := srv.ListenAndServe() - if err == nil { - srv.Stop() - t.Fatalf("Low port startup should fail") - } -} diff --git a/test/file.go b/test/file.go new file mode 100644 index 000000000..b6068a32b --- /dev/null +++ b/test/file.go @@ -0,0 +1,20 @@ +package test + +import ( + "io/ioutil" + "os" + "testing" +) + +// TempFile will create a temporary file on disk and returns the name and a cleanup function to remove it later. +func TempFile(t *testing.T, dir, content string) (string, func(), error) { + f, err := ioutil.TempFile(dir, "go-test-tmpfile") + if err != nil { + return "", nil, err + } + if err := ioutil.WriteFile(f.Name(), []byte(content), 0644); err != nil { + return "", nil, err + } + rmFunc := func() { os.Remove(f.Name()) } + return f.Name(), rmFunc, nil +} diff --git a/test/file_test.go b/test/file_test.go new file mode 100644 index 000000000..950ea7cff --- /dev/null +++ b/test/file_test.go @@ -0,0 +1,11 @@ +package test + +import "testing" + +func TestTempFile(t *testing.T) { + _, f, e := TempFile(t, ".", "test") + if e != nil { + t.Fatalf("failed to create temp file: %s", e) + } + defer f() +} diff --git a/test/helpers.go b/test/helpers.go new file mode 100644 index 000000000..01a6f156b --- /dev/null +++ b/test/helpers.go @@ -0,0 +1,250 @@ +package test + +import ( + "testing" + + "github.com/miekg/dns" + "golang.org/x/net/context" +) + +type Sect int + +const ( + Answer Sect = iota + Ns + Extra +) + +type RRSet []dns.RR + +func (p RRSet) Len() int { return len(p) } +func (p RRSet) Swap(i, j int) { p[i], p[j] = p[j], p[i] } +func (p RRSet) Less(i, j int) bool { return p[i].String() < p[j].String() } + +// If the TTL of a record is 303 we don't care what the TTL is. +type Case struct { + Qname string + Qtype uint16 + Rcode int + Do bool + Answer []dns.RR + Ns []dns.RR + Extra []dns.RR +} + +func (c Case) Msg() *dns.Msg { + m := new(dns.Msg) + m.SetQuestion(dns.Fqdn(c.Qname), c.Qtype) + if c.Do { + o := new(dns.OPT) + o.Hdr.Name = "." + o.Hdr.Rrtype = dns.TypeOPT + o.SetDo() + o.SetUDPSize(4096) + m.Extra = []dns.RR{o} + } + return m +} + +func A(rr string) *dns.A { r, _ := dns.NewRR(rr); return r.(*dns.A) } +func AAAA(rr string) *dns.AAAA { r, _ := dns.NewRR(rr); return r.(*dns.AAAA) } +func CNAME(rr string) *dns.CNAME { r, _ := dns.NewRR(rr); return r.(*dns.CNAME) } +func SRV(rr string) *dns.SRV { r, _ := dns.NewRR(rr); return r.(*dns.SRV) } +func SOA(rr string) *dns.SOA { r, _ := dns.NewRR(rr); return r.(*dns.SOA) } +func NS(rr string) *dns.NS { r, _ := dns.NewRR(rr); return r.(*dns.NS) } +func PTR(rr string) *dns.PTR { r, _ := dns.NewRR(rr); return r.(*dns.PTR) } +func TXT(rr string) *dns.TXT { r, _ := dns.NewRR(rr); return r.(*dns.TXT) } +func MX(rr string) *dns.MX { r, _ := dns.NewRR(rr); return r.(*dns.MX) } +func RRSIG(rr string) *dns.RRSIG { r, _ := dns.NewRR(rr); return r.(*dns.RRSIG) } +func NSEC(rr string) *dns.NSEC { r, _ := dns.NewRR(rr); return r.(*dns.NSEC) } +func DNSKEY(rr string) *dns.DNSKEY { r, _ := dns.NewRR(rr); return r.(*dns.DNSKEY) } + +func OPT(bufsize int, do bool) *dns.OPT { + o := new(dns.OPT) + o.Hdr.Name = "." + o.Hdr.Rrtype = dns.TypeOPT + o.SetVersion(0) + o.SetUDPSize(uint16(bufsize)) + if do { + o.SetDo() + } + return o +} + +func Header(t *testing.T, tc Case, resp *dns.Msg) bool { + if resp.Rcode != tc.Rcode { + t.Errorf("rcode is %q, expected %q", dns.RcodeToString[resp.Rcode], dns.RcodeToString[tc.Rcode]) + return false + } + + if len(resp.Answer) != len(tc.Answer) { + t.Errorf("answer for %q contained %d results, %d expected", tc.Qname, len(resp.Answer), len(tc.Answer)) + return false + } + if len(resp.Ns) != len(tc.Ns) { + t.Errorf("authority for %q contained %d results, %d expected", tc.Qname, len(resp.Ns), len(tc.Ns)) + return false + } + if len(resp.Extra) != len(tc.Extra) { + t.Errorf("additional for %q contained %d results, %d expected", tc.Qname, len(resp.Extra), len(tc.Extra)) + return false + } + return true +} + +func Section(t *testing.T, tc Case, sect Sect, rr []dns.RR) bool { + section := []dns.RR{} + switch sect { + case 0: + section = tc.Answer + case 1: + section = tc.Ns + case 2: + section = tc.Extra + } + + for i, a := range rr { + if a.Header().Name != section[i].Header().Name { + t.Errorf("rr %d should have a Header Name of %q, but has %q", i, section[i].Header().Name, a.Header().Name) + return false + } + // 303 signals: don't care what the ttl is. + if section[i].Header().Ttl != 303 && a.Header().Ttl != section[i].Header().Ttl { + if _, ok := section[i].(*dns.OPT); !ok { + // we check edns0 bufize on this one + t.Errorf("rr %d should have a Header TTL of %d, but has %d", i, section[i].Header().Ttl, a.Header().Ttl) + return false + } + } + if a.Header().Rrtype != section[i].Header().Rrtype { + t.Errorf("rr %d should have a header rr type of %d, but has %d", i, section[i].Header().Rrtype, a.Header().Rrtype) + return false + } + + switch x := a.(type) { + case *dns.SRV: + if x.Priority != section[i].(*dns.SRV).Priority { + t.Errorf("rr %d should have a Priority of %d, but has %d", i, section[i].(*dns.SRV).Priority, x.Priority) + return false + } + if x.Weight != section[i].(*dns.SRV).Weight { + t.Errorf("rr %d should have a Weight of %d, but has %d", i, section[i].(*dns.SRV).Weight, x.Weight) + return false + } + if x.Port != section[i].(*dns.SRV).Port { + t.Errorf("rr %d should have a Port of %d, but has %d", i, section[i].(*dns.SRV).Port, x.Port) + return false + } + if x.Target != section[i].(*dns.SRV).Target { + t.Errorf("rr %d should have a Target of %q, but has %q", i, section[i].(*dns.SRV).Target, x.Target) + return false + } + case *dns.RRSIG: + if x.TypeCovered != section[i].(*dns.RRSIG).TypeCovered { + t.Errorf("rr %d should have a TypeCovered of %d, but has %d", i, section[i].(*dns.RRSIG).TypeCovered, x.TypeCovered) + return false + } + if x.Labels != section[i].(*dns.RRSIG).Labels { + t.Errorf("rr %d should have a Labels of %d, but has %d", i, section[i].(*dns.RRSIG).Labels, x.Labels) + return false + } + if x.SignerName != section[i].(*dns.RRSIG).SignerName { + t.Errorf("rr %d should have a SignerName of %d, but has %d", i, section[i].(*dns.RRSIG).SignerName, x.SignerName) + return false + } + case *dns.NSEC: + if x.NextDomain != section[i].(*dns.NSEC).NextDomain { + t.Errorf("rr %d should have a NextDomain of %d, but has %d", i, section[i].(*dns.NSEC).NextDomain, x.NextDomain) + return false + } + // TypeBitMap + case *dns.A: + if x.A.String() != section[i].(*dns.A).A.String() { + t.Errorf("rr %d should have a Address of %q, but has %q", i, section[i].(*dns.A).A.String(), x.A.String()) + return false + } + case *dns.AAAA: + if x.AAAA.String() != section[i].(*dns.AAAA).AAAA.String() { + t.Errorf("rr %d should have a Address of %q, but has %q", i, section[i].(*dns.AAAA).AAAA.String(), x.AAAA.String()) + return false + } + case *dns.TXT: + for j, txt := range x.Txt { + if txt != section[i].(*dns.TXT).Txt[j] { + t.Errorf("rr %d should have a Txt of %q, but has %q", i, section[i].(*dns.TXT).Txt[j], txt) + return false + } + } + case *dns.SOA: + tt := section[i].(*dns.SOA) + if x.Ns != tt.Ns { + t.Errorf("SOA nameserver should be %q, but is %q", x.Ns, tt.Ns) + return false + } + case *dns.PTR: + tt := section[i].(*dns.PTR) + if x.Ptr != tt.Ptr { + t.Errorf("PTR ptr should be %q, but is %q", x.Ptr, tt.Ptr) + return false + } + case *dns.CNAME: + tt := section[i].(*dns.CNAME) + if x.Target != tt.Target { + t.Errorf("CNAME target should be %q, but is %q", x.Target, tt.Target) + return false + } + case *dns.MX: + tt := section[i].(*dns.MX) + if x.Mx != tt.Mx { + t.Errorf("MX Mx should be %q, but is %q", x.Mx, tt.Mx) + return false + } + if x.Preference != tt.Preference { + t.Errorf("MX Preference should be %q, but is %q", x.Preference, tt.Preference) + return false + } + case *dns.NS: + tt := section[i].(*dns.NS) + if x.Ns != tt.Ns { + t.Errorf("NS nameserver should be %q, but is %q", x.Ns, tt.Ns) + return false + } + case *dns.OPT: + tt := section[i].(*dns.OPT) + if x.UDPSize() != tt.UDPSize() { + t.Errorf("OPT UDPSize should be %d, but is %d", tt.UDPSize(), x.UDPSize()) + return false + } + if x.Do() != tt.Do() { + t.Errorf("OPT DO should be %t, but is %t", tt.Do(), x.Do()) + return false + } + } + } + return true +} + +func ErrorHandler() Handler { + return HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + m := new(dns.Msg) + m.SetRcode(r, dns.RcodeServerFailure) + w.WriteMsg(m) + return dns.RcodeServerFailure, nil + }) +} + +// Copied here to prevent an import cycle. +type ( + // HandlerFunc is a convenience type like dns.HandlerFunc, except + // ServeDNS returns an rcode and an error. + HandlerFunc func(context.Context, dns.ResponseWriter, *dns.Msg) (int, error) + + Handler interface { + ServeDNS(context.Context, dns.ResponseWriter, *dns.Msg) (int, error) + } +) + +// ServeDNS implements the Handler interface. +func (f HandlerFunc) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + return f(ctx, w, r) +} diff --git a/test/helpers_test.go b/test/helpers_test.go new file mode 100644 index 000000000..2aa1fe106 --- /dev/null +++ b/test/helpers_test.go @@ -0,0 +1,5 @@ +package test + +import "testing" + +func TestA(t *testing.T) { A("miek.nl. IN A 127.0.0.1") } // should not crash diff --git a/test/kubernetes_test.go b/test/kubernetes_test.go index 55939e350..009b50657 100644 --- a/test/kubernetes_test.go +++ b/test/kubernetes_test.go @@ -3,12 +3,12 @@ package test import ( - "fmt" "io/ioutil" "log" "testing" "github.com/miekg/coredns/middleware/kubernetes/k8stest" + "github.com/miekg/dns" ) @@ -64,9 +64,8 @@ var testdataLookupSRV = []struct { } func TestK8sIntegration(t *testing.T) { - t.Log(" === RUN testLookupA") + // subtests here (Go 1.7 feature). testLookupA(t) - t.Log(" === RUN testLookupSRV") testLookupSRV(t) } @@ -75,7 +74,7 @@ func testLookupA(t *testing.T) { t.Skip("Skipping Kubernetes Integration tests. Kubernetes is not running") } - coreFile := + corefile := `.:0 { kubernetes coredns.local { endpoint http://localhost:8080 @@ -83,16 +82,20 @@ func testLookupA(t *testing.T) { } ` - server, _, udp, err := Server(t, coreFile) + server, err := CoreDNSServer(corefile) if err != nil { - t.Fatal("Could not get server: %s", err) + t.Fatalf("could not get CoreDNS serving instance: %s", err) + } + + udp, _ := CoreDNSServerPorts(server, 0) + if udp == "" { + t.Fatalf("could not get udp listening port") } defer server.Stop() log.SetOutput(ioutil.Discard) for _, testData := range testdataLookupA { - t.Logf("[log] Testing query string: '%v'\n", testData.Query) dnsClient := new(dns.Client) dnsMessage := new(dns.Msg) @@ -125,7 +128,7 @@ func testLookupSRV(t *testing.T) { t.Skip("Skipping Kubernetes Integration tests. Kubernetes is not running") } - coreFile := + corefile := `.:0 { kubernetes coredns.local { endpoint http://localhost:8080 @@ -133,9 +136,13 @@ func testLookupSRV(t *testing.T) { } ` - server, _, udp, err := Server(t, coreFile) + server, err := CoreDNSServer(corefile) if err != nil { - t.Fatal("Could not get server: %s", err) + t.Fatalf("could not get CoreDNS serving instance: %s", err) + } + udp, _ := CoreDNSServerPorts(server, 0) + if udp == "" { + t.Fatalf("could not get udp listening port") } defer server.Stop() @@ -144,7 +151,6 @@ func testLookupSRV(t *testing.T) { // TODO: Add checks for A records in additional section for _, testData := range testdataLookupSRV { - t.Logf("[log] Testing query string: '%v'\n", testData.Query) dnsClient := new(dns.Client) dnsMessage := new(dns.Msg) @@ -158,7 +164,6 @@ func testLookupSRV(t *testing.T) { // Count SRV records in the answer section srvRecordCount := 0 for _, a := range res.Answer { - fmt.Printf("RR: %v\n", a) if a.Header().Rrtype == dns.TypeSRV { srvRecordCount++ } diff --git a/test/middleware_dnssec_test.go b/test/middleware_dnssec_test.go index 434e7ad56..afde72a54 100644 --- a/test/middleware_dnssec_test.go +++ b/test/middleware_dnssec_test.go @@ -29,11 +29,12 @@ func TestLookupBalanceRewriteCacheDnssec(t *testing.T) { loadbalance } ` - ex, _, udp, err := Server(t, corefile) + ex, err := CoreDNSServer(corefile) if err != nil { - t.Errorf("Could get server to start: %s", err) - return + t.Fatalf("could not get CoreDNS serving instance: %s", err) } + + udp, _ := CoreDNSServerPorts(ex, 0) defer ex.Stop() log.SetOutput(ioutil.Discard) diff --git a/test/middleware_test.go b/test/middleware_test.go index ec90f0d71..f7fefe78a 100644 --- a/test/middleware_test.go +++ b/test/middleware_test.go @@ -10,7 +10,7 @@ import ( "github.com/miekg/dns" ) -func BenchmarkLookupBalanceRewriteCache(b *testing.B) { +func benchmarkLookupBalanceRewriteCache(b *testing.B) { t := new(testing.T) name, rm, err := test.TempFile(t, ".", exampleOrg) if err != nil { @@ -24,10 +24,12 @@ func BenchmarkLookupBalanceRewriteCache(b *testing.B) { loadbalance } ` - ex, _, udp, err := Server(t, corefile) + + ex, err := CoreDNSServer(corefile) if err != nil { - t.Fatalf("Could get server: %s", err) + t.Fatalf("could not get CoreDNS serving instance: %s", err) } + udp, _ := CoreDNSServerPorts(ex, 0) defer ex.Stop() log.SetOutput(ioutil.Discard) diff --git a/test/proxy_test.go b/test/proxy_test.go index 56ef159fb..ca04e1ae8 100644 --- a/test/proxy_test.go +++ b/test/proxy_test.go @@ -28,14 +28,20 @@ func TestLookupProxy(t *testing.T) { defer rm() corefile := `example.org:0 { - file ` + name + ` + file ` + name + ` } ` - ex, _, udp, err := Server(t, corefile) + + i, err := CoreDNSServer(corefile) if err != nil { - t.Fatalf("Could get server: %s", err) + t.Fatalf("could not get CoreDNS serving instance: %s", err) } - defer ex.Stop() + + udp, _ := CoreDNSServerPorts(i, 0) + if udp == "" { + t.Fatalf("could not get udp listening port") + } + defer i.Stop() log.SetOutput(ioutil.Discard) @@ -43,8 +49,7 @@ func TestLookupProxy(t *testing.T) { state := middleware.State{W: &test.ResponseWriter{}, Req: new(dns.Msg)} resp, err := p.Lookup(state, "example.org.", dns.TypeA) if err != nil { - t.Error("Expected to receive reply, but didn't") - return + t.Fatal("Expected to receive reply, but didn't") } // expect answer section with A record in it if len(resp.Answer) == 0 { diff --git a/test/responsewriter.go b/test/responsewriter.go new file mode 100644 index 000000000..fb70d7e8d --- /dev/null +++ b/test/responsewriter.go @@ -0,0 +1,28 @@ +package test + +import ( + "net" + + "github.com/miekg/dns" +) + +type ResponseWriter struct{} + +func (t *ResponseWriter) LocalAddr() net.Addr { + ip := net.ParseIP("127.0.0.1") + port := 53 + return &net.UDPAddr{IP: ip, Port: port, Zone: ""} +} + +func (t *ResponseWriter) RemoteAddr() net.Addr { + ip := net.ParseIP("10.240.0.1") + port := 40212 + return &net.UDPAddr{IP: ip, Port: port, Zone: ""} +} + +func (t *ResponseWriter) WriteMsg(m *dns.Msg) error { return nil } +func (t *ResponseWriter) Write(buf []byte) (int, error) { return len(buf), nil } +func (t *ResponseWriter) Close() error { return nil } +func (t *ResponseWriter) TsigStatus() error { return nil } +func (t *ResponseWriter) TsigTimersOnly(bool) { return } +func (t *ResponseWriter) Hijack() { return } diff --git a/test/server.go b/test/server.go new file mode 100644 index 000000000..324ffdd8f --- /dev/null +++ b/test/server.go @@ -0,0 +1,91 @@ +package test + +import ( + "net" + "sync" + "testing" + "time" + + _ "github.com/miekg/coredns/core" + + "github.com/mholt/caddy" + "github.com/miekg/dns" +) + +func TCPServer(t *testing.T, laddr string) (*dns.Server, string, error) { + l, err := net.Listen("tcp", laddr) + if err != nil { + return nil, "", err + } + + server := &dns.Server{Listener: l, ReadTimeout: time.Hour, WriteTimeout: time.Hour} + + waitLock := sync.Mutex{} + waitLock.Lock() + server.NotifyStartedFunc = func() { t.Logf("started TCP server on %s", l.Addr()); waitLock.Unlock() } + + go func() { + server.ActivateAndServe() + l.Close() + }() + + waitLock.Lock() + return server, l.Addr().String(), nil +} + +func UDPServer(t *testing.T, laddr string) (*dns.Server, string, error) { + pc, err := net.ListenPacket("udp", laddr) + if err != nil { + return nil, "", err + } + server := &dns.Server{PacketConn: pc, ReadTimeout: time.Hour, WriteTimeout: time.Hour} + + waitLock := sync.Mutex{} + waitLock.Lock() + server.NotifyStartedFunc = func() { t.Logf("started UDP server on %s", pc.LocalAddr()); waitLock.Unlock() } + + go func() { + server.ActivateAndServe() + pc.Close() + }() + + waitLock.Lock() + return server, pc.LocalAddr().String(), nil +} + +// CoreDNSServer returns a test server. It just takes a normal Corefile as input. +func CoreDNSServer(corefile string) (*caddy.Instance, error) { return caddy.Start(NewInput(corefile)) } + +// CoreDNSSserverStop stops a server. +func CoreDNSServerStop(i *caddy.Instance) { i.Stop() } + +// CoreDNSServeRPorts returns the ports the instance is listening on. The integer k indicates +// which ServerListener you want. +func CoreDNSServerPorts(i *caddy.Instance, k int) (udp, tcp string) { + srvs := i.Servers() + if len(srvs) < k+1 { + return "", "" + } + u := srvs[k].LocalAddr() + t := srvs[k].Addr() + + if u != nil { + udp = u.String() + } + if t != nil { + tcp = t.String() + } + return +} + +type Input struct { + corefile []byte +} + +func NewInput(corefile string) *Input { + return &Input{corefile: []byte(corefile)} +} + +func (i *Input) Body() []byte { return i.corefile } +func (i *Input) Path() string { return "Corefile" } +func (i *Input) ServerType() string { return "dns" } diff --git a/test/server_test.go b/test/server_test.go index 6a86d022b..a03285bf3 100644 --- a/test/server_test.go +++ b/test/server_test.go @@ -12,24 +12,28 @@ func TestProxyToChaosServer(t *testing.T) { chaos CoreDNS-001 miek@miek.nl } ` - chaos, tcpCH, udpCH, err := Server(t, corefile) + chaos, err := CoreDNSServer(corefile) if err != nil { - t.Fatalf("Could get server: %s", err) + t.Fatalf("could not get CoreDNS serving instance: %s", err) } + + udpChaos, tcpChaos := CoreDNSServerPorts(chaos, 0) defer chaos.Stop() corefileProxy := `.:0 { - proxy . ` + udpCH + ` + proxy . ` + udpChaos + ` } ` - proxy, _, udp, err := Server(t, corefileProxy) + proxy, err := CoreDNSServer(corefileProxy) if err != nil { - t.Fatalf("Could get server: %s", err) + t.Fatalf("could not get CoreDNS serving instance") } + + udp, _ := CoreDNSServerPorts(proxy, 0) defer proxy.Stop() - chaosTest(t, udpCH, "udp") - chaosTest(t, tcpCH, "tcp") + chaosTest(t, udpChaos, "udp") + chaosTest(t, tcpChaos, "tcp") chaosTest(t, udp, "udp") // chaosTest(t, tcp, "tcp"), commented out because we use the original transport to reach the diff --git a/test/tests.go b/test/tests.go index d38bf955f..0f9d12bae 100644 --- a/test/tests.go +++ b/test/tests.go @@ -1,12 +1,7 @@ package test import ( - "testing" - "time" - - "github.com/miekg/coredns/core" "github.com/miekg/coredns/middleware" - "github.com/miekg/coredns/server" "github.com/miekg/dns" ) @@ -25,16 +20,3 @@ func Exchange(m *dns.Msg, server, net string) (*dns.Msg, error) { c.Net = net return middleware.Exchange(c, m, server) } - -// Server returns a test server and the tcp and udp listeners addresses. -func Server(t *testing.T, corefile string) (*server.Server, string, string, error) { - srv, err := core.TestServer(t, corefile) - if err != nil { - return nil, "", "", err - } - go srv.ListenAndServe() - - time.Sleep(1 * time.Second) // yeah... I regret nothing - tcp, udp := srv.LocalAddr() - return srv, tcp.String(), udp.String(), nil -}