diff --git a/core/dnsserver/address.go b/core/dnsserver/address.go index 8f544e97a..36894aeea 100644 --- a/core/dnsserver/address.go +++ b/core/dnsserver/address.go @@ -6,6 +6,7 @@ import ( "strings" "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/pkg/transport" "github.com/miekg/dns" ) @@ -27,43 +28,13 @@ func (z zoneAddr) String() string { return s } -// Transport returns the protocol of the string s -func Transport(s string) string { - switch { - case strings.HasPrefix(s, TransportTLS+"://"): - return TransportTLS - case strings.HasPrefix(s, TransportDNS+"://"): - return TransportDNS - case strings.HasPrefix(s, TransportGRPC+"://"): - return TransportGRPC - case strings.HasPrefix(s, TransportHTTPS+"://"): - return TransportHTTPS - } - return TransportDNS -} - // normalizeZone parses an zone string into a structured format with separate // host, and port portions, as well as the original input string. func normalizeZone(str string) (zoneAddr, error) { var err error - // Default to DNS if there isn't a transport protocol prefix. - trans := TransportDNS - - switch { - case strings.HasPrefix(str, TransportTLS+"://"): - trans = TransportTLS - str = str[len(TransportTLS+"://"):] - case strings.HasPrefix(str, TransportDNS+"://"): - trans = TransportDNS - str = str[len(TransportDNS+"://"):] - case strings.HasPrefix(str, TransportGRPC+"://"): - trans = TransportGRPC - str = str[len(TransportGRPC+"://"):] - case strings.HasPrefix(str, TransportHTTPS+"://"): - trans = TransportHTTPS - str = str[len(TransportHTTPS+"://"):] - } + var trans string + trans, str = transport.Parse(str) host, port, ipnet, err := plugin.SplitHostPort(str) if err != nil { @@ -71,17 +42,15 @@ func normalizeZone(str string) (zoneAddr, error) { } if port == "" { - if trans == TransportDNS { + switch trans { + case transport.DNS: port = Port - } - if trans == TransportTLS { - port = TLSPort - } - if trans == TransportGRPC { - port = GRPCPort - } - if trans == TransportHTTPS { - port = HTTPSPort + case transport.TLS: + port = transport.TLSPort + case transport.GRPC: + port = transport.GRPCPort + case transport.HTTPS: + port = transport.HTTPSPort } } @@ -103,14 +72,6 @@ func SplitProtocolHostPort(address string) (protocol string, ip string, port str } } -// Supported transports. -const ( - TransportDNS = "dns" - TransportTLS = "tls" - TransportGRPC = "grpc" - TransportHTTPS = "https" -) - type zoneOverlap struct { registeredAddr map[zoneAddr]zoneAddr // each zoneAddr is registered once by its key unboundOverlap map[zoneAddr]zoneAddr // the "no bind" equiv ZoneAdddr is registered by its original key diff --git a/core/dnsserver/address_test.go b/core/dnsserver/address_test.go index a83824f52..6d4d0beab 100644 --- a/core/dnsserver/address_test.go +++ b/core/dnsserver/address_test.go @@ -192,21 +192,3 @@ func TestOverlapAddressChecker(t *testing.T) { } } } - -func TestTransport(t *testing.T) { - for i, test := range []struct { - input string - expected string - }{ - {"dns://.:53", TransportDNS}, - {"2003::1/64.:53", TransportDNS}, - {"grpc://example.org:1443 ", TransportGRPC}, - {"tls://example.org ", TransportTLS}, - {"https://example.org ", TransportHTTPS}, - } { - actual := Transport(test.input) - if actual != test.expected { - t.Errorf("Test %d: Expected %s but got %s", i, test.expected, actual) - } - } -} diff --git a/core/dnsserver/register.go b/core/dnsserver/register.go index ced2519af..47595b5e3 100644 --- a/core/dnsserver/register.go +++ b/core/dnsserver/register.go @@ -9,6 +9,7 @@ import ( "github.com/coredns/coredns/plugin" "github.com/coredns/coredns/plugin/pkg/dnsutil" + "github.com/coredns/coredns/plugin/pkg/transport" "github.com/mholt/caddy" "github.com/mholt/caddy/caddyfile" @@ -111,29 +112,29 @@ func (h *dnsContext) MakeServers() ([]caddy.Server, error) { var servers []caddy.Server for addr, group := range groups { // switch on addr - switch Transport(addr) { - case TransportDNS: + switch tr, _ := transport.Parse(addr); tr { + case transport.DNS: s, err := NewServer(addr, group) if err != nil { return nil, err } servers = append(servers, s) - case TransportTLS: + case transport.TLS: s, err := NewServerTLS(addr, group) if err != nil { return nil, err } servers = append(servers, s) - case TransportGRPC: + case transport.GRPC: s, err := NewServergRPC(addr, group) if err != nil { return nil, err } servers = append(servers, s) - case TransportHTTPS: + case transport.HTTPS: s, err := NewServerHTTPS(addr, group) if err != nil { return nil, err @@ -234,16 +235,8 @@ func groupConfigsByListenAddr(configs []*Config) (map[string][]*Config, error) { return groups, nil } -const ( - // DefaultPort is the default port. - DefaultPort = "53" - // TLSPort is the default port for DNS-over-TLS. - TLSPort = "853" - // GRPCPort is the default port for DNS-over-gRPC. - GRPCPort = "443" - // HTTPSPort is the default port for DNS-over-HTTPS. - HTTPSPort = "443" -) +// DefaultPort is the default port. +const DefaultPort = "53" // These "soft defaults" are configurable by // command line flags, etc. diff --git a/core/dnsserver/server.go b/core/dnsserver/server.go index 21a52f22c..47d406e31 100644 --- a/core/dnsserver/server.go +++ b/core/dnsserver/server.go @@ -15,6 +15,7 @@ import ( "github.com/coredns/coredns/plugin/pkg/log" "github.com/coredns/coredns/plugin/pkg/rcode" "github.com/coredns/coredns/plugin/pkg/trace" + "github.com/coredns/coredns/plugin/pkg/transport" "github.com/coredns/coredns/request" "github.com/miekg/dns" @@ -134,7 +135,7 @@ func (s *Server) ServePacket(p net.PacketConn) error { // Listen implements caddy.TCPServer interface. func (s *Server) Listen() (net.Listener, error) { - l, err := net.Listen("tcp", s.Addr[len(TransportDNS+"://"):]) + l, err := net.Listen("tcp", s.Addr[len(transport.DNS+"://"):]) if err != nil { return nil, err } @@ -143,7 +144,7 @@ func (s *Server) Listen() (net.Listener, error) { // ListenPacket implements caddy.UDPServer interface. func (s *Server) ListenPacket() (net.PacketConn, error) { - p, err := net.ListenPacket("udp", s.Addr[len(TransportDNS+"://"):]) + p, err := net.ListenPacket("udp", s.Addr[len(transport.DNS+"://"):]) if err != nil { return nil, err } diff --git a/core/dnsserver/server_grpc.go b/core/dnsserver/server_grpc.go index e5b87749d..7de36a5fd 100644 --- a/core/dnsserver/server_grpc.go +++ b/core/dnsserver/server_grpc.go @@ -8,6 +8,7 @@ import ( "net" "github.com/coredns/coredns/pb" + "github.com/coredns/coredns/plugin/pkg/transport" "github.com/coredns/coredns/plugin/pkg/watch" "github.com/grpc-ecosystem/grpc-opentracing/go/otgrpc" @@ -73,7 +74,7 @@ func (s *ServergRPC) ServePacket(p net.PacketConn) error { return nil } // Listen implements caddy.TCPServer interface. func (s *ServergRPC) Listen() (net.Listener, error) { - l, err := net.Listen("tcp", s.Addr[len(TransportGRPC+"://"):]) + l, err := net.Listen("tcp", s.Addr[len(transport.GRPC+"://"):]) if err != nil { return nil, err } @@ -90,7 +91,7 @@ func (s *ServergRPC) OnStartupComplete() { return } - out := startUpZones(TransportGRPC+"://", s.Addr, s.zones) + out := startUpZones(transport.GRPC+"://", s.Addr, s.zones) if out != "" { fmt.Print(out) } diff --git a/core/dnsserver/server_https.go b/core/dnsserver/server_https.go index 9b1eaaa7e..1e184e044 100644 --- a/core/dnsserver/server_https.go +++ b/core/dnsserver/server_https.go @@ -12,6 +12,7 @@ import ( "github.com/coredns/coredns/plugin/pkg/dnsutil" "github.com/coredns/coredns/plugin/pkg/doh" "github.com/coredns/coredns/plugin/pkg/response" + "github.com/coredns/coredns/plugin/pkg/transport" ) // ServerHTTPS represents an instance of a DNS-over-HTTPS server. @@ -60,7 +61,7 @@ func (s *ServerHTTPS) ServePacket(p net.PacketConn) error { return nil } // Listen implements caddy.TCPServer interface. func (s *ServerHTTPS) Listen() (net.Listener, error) { - l, err := net.Listen("tcp", s.Addr[len(TransportHTTPS+"://"):]) + l, err := net.Listen("tcp", s.Addr[len(transport.HTTPS+"://"):]) if err != nil { return nil, err } @@ -77,7 +78,7 @@ func (s *ServerHTTPS) OnStartupComplete() { return } - out := startUpZones(TransportHTTPS+"://", s.Addr, s.zones) + out := startUpZones(transport.HTTPS+"://", s.Addr, s.zones) if out != "" { fmt.Print(out) } diff --git a/core/dnsserver/server_tls.go b/core/dnsserver/server_tls.go index 0fd0c1fbe..a63ac848a 100644 --- a/core/dnsserver/server_tls.go +++ b/core/dnsserver/server_tls.go @@ -6,6 +6,8 @@ import ( "fmt" "net" + "github.com/coredns/coredns/plugin/pkg/transport" + "github.com/miekg/dns" ) @@ -55,7 +57,7 @@ func (s *ServerTLS) ServePacket(p net.PacketConn) error { return nil } // Listen implements caddy.TCPServer interface. func (s *ServerTLS) Listen() (net.Listener, error) { - l, err := net.Listen("tcp", s.Addr[len(TransportTLS+"://"):]) + l, err := net.Listen("tcp", s.Addr[len(transport.TLS+"://"):]) if err != nil { return nil, err } @@ -72,7 +74,7 @@ func (s *ServerTLS) OnStartupComplete() { return } - out := startUpZones(TransportTLS+"://", s.Addr, s.zones) + out := startUpZones(transport.TLS+"://", s.Addr, s.zones) if out != "" { fmt.Print(out) } diff --git a/plugin/forward/connect.go b/plugin/forward/connect.go index 6102bbe15..64edb395e 100644 --- a/plugin/forward/connect.go +++ b/plugin/forward/connect.go @@ -35,16 +35,16 @@ func averageTimeout(currentAvg *int64, observedDuration time.Duration, weight in atomic.AddInt64(currentAvg, int64(observedDuration-dt)/weight) } -func (t *transport) dialTimeout() time.Duration { +func (t *Transport) dialTimeout() time.Duration { return limitTimeout(&t.avgDialTime, minDialTimeout, maxDialTimeout) } -func (t *transport) updateDialTimeout(newDialTime time.Duration) { +func (t *Transport) updateDialTimeout(newDialTime time.Duration) { averageTimeout(&t.avgDialTime, newDialTime, cumulativeAvgWeight) } // Dial dials the address configured in transport, potentially reusing a connection or creating a new one. -func (t *transport) Dial(proto string) (*dns.Conn, bool, error) { +func (t *Transport) Dial(proto string) (*dns.Conn, bool, error) { // If tls has been configured; use it. if t.tlsConfig != nil { proto = "tcp-tls" diff --git a/plugin/forward/forward_test.go b/plugin/forward/forward_test.go index 82844f811..dfdde933f 100644 --- a/plugin/forward/forward_test.go +++ b/plugin/forward/forward_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/pkg/transport" "github.com/coredns/coredns/plugin/test" "github.com/coredns/coredns/request" @@ -19,7 +20,7 @@ func TestForward(t *testing.T) { }) defer s.Close() - p := NewProxy(s.Addr, DNS) + p := NewProxy(s.Addr, transport.DNS) f := New() f.SetProxy(p) defer f.Close() @@ -51,7 +52,7 @@ func TestForwardRefused(t *testing.T) { }) defer s.Close() - p := NewProxy(s.Addr, DNS) + p := NewProxy(s.Addr, transport.DNS) f := New() f.SetProxy(p) defer f.Close() diff --git a/plugin/forward/health.go b/plugin/forward/health.go index 4d3278f6d..a64b74122 100644 --- a/plugin/forward/health.go +++ b/plugin/forward/health.go @@ -5,6 +5,8 @@ import ( "sync/atomic" "time" + "github.com/coredns/coredns/plugin/pkg/transport" + "github.com/miekg/dns" ) @@ -17,10 +19,10 @@ type HealthChecker interface { // dnsHc is a health checker for a DNS endpoint (DNS, and DoT). type dnsHc struct{ c *dns.Client } -// NewHealthChecker returns a new HealthChecker based on protocol. -func NewHealthChecker(protocol int) HealthChecker { - switch protocol { - case DNS, TLS: +// NewHealthChecker returns a new HealthChecker based on transport. +func NewHealthChecker(trans string) HealthChecker { + switch trans { + case transport.DNS, transport.TLS: c := new(dns.Client) c.Net = "udp" c.ReadTimeout = 1 * time.Second diff --git a/plugin/forward/health_test.go b/plugin/forward/health_test.go index 75d57f285..59a0dde13 100644 --- a/plugin/forward/health_test.go +++ b/plugin/forward/health_test.go @@ -7,6 +7,7 @@ import ( "time" "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/pkg/transport" "github.com/coredns/coredns/plugin/test" "github.com/miekg/dns" @@ -25,7 +26,7 @@ func TestHealth(t *testing.T) { }) defer s.Close() - p := NewProxy(s.Addr, DNS) + p := NewProxy(s.Addr, transport.DNS) f := New() f.SetProxy(p) defer f.Close() @@ -65,7 +66,7 @@ func TestHealthTimeout(t *testing.T) { }) defer s.Close() - p := NewProxy(s.Addr, DNS) + p := NewProxy(s.Addr, transport.DNS) f := New() f.SetProxy(p) defer f.Close() @@ -109,7 +110,7 @@ func TestHealthFailTwice(t *testing.T) { }) defer s.Close() - p := NewProxy(s.Addr, DNS) + p := NewProxy(s.Addr, transport.DNS) f := New() f.SetProxy(p) defer f.Close() @@ -132,7 +133,7 @@ func TestHealthMaxFails(t *testing.T) { }) defer s.Close() - p := NewProxy(s.Addr, DNS) + p := NewProxy(s.Addr, transport.DNS) f := New() f.maxfails = 2 f.SetProxy(p) @@ -163,7 +164,7 @@ func TestHealthNoMaxFails(t *testing.T) { }) defer s.Close() - p := NewProxy(s.Addr, DNS) + p := NewProxy(s.Addr, transport.DNS) f := New() f.maxfails = 0 f.SetProxy(p) diff --git a/plugin/forward/lookup.go b/plugin/forward/lookup.go index cba24f85f..f6c9a0745 100644 --- a/plugin/forward/lookup.go +++ b/plugin/forward/lookup.go @@ -7,6 +7,7 @@ package forward import ( "context" + "github.com/coredns/coredns/plugin/pkg/transport" "github.com/coredns/coredns/request" "github.com/miekg/dns" @@ -81,7 +82,7 @@ func (f *Forward) Lookup(state request.Request, name string, typ uint16) (*dns.M func NewLookup(addr []string) *Forward { f := New() for i := range addr { - p := NewProxy(addr[i], DNS) + p := NewProxy(addr[i], transport.DNS) f.SetProxy(p) } return f diff --git a/plugin/forward/lookup_test.go b/plugin/forward/lookup_test.go index 1968ef979..bb3cc4143 100644 --- a/plugin/forward/lookup_test.go +++ b/plugin/forward/lookup_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/pkg/transport" "github.com/coredns/coredns/plugin/test" "github.com/coredns/coredns/request" @@ -19,7 +20,7 @@ func TestLookup(t *testing.T) { }) defer s.Close() - p := NewProxy(s.Addr, DNS) + p := NewProxy(s.Addr, transport.DNS) f := New() f.SetProxy(p) defer f.Close() diff --git a/plugin/forward/persistent.go b/plugin/forward/persistent.go index 4da1514fe..fabdc70b2 100644 --- a/plugin/forward/persistent.go +++ b/plugin/forward/persistent.go @@ -15,10 +15,10 @@ type persistConn struct { used time.Time } -// transport hold the persistent cache. -type transport struct { +// Transport hold the persistent cache. +type Transport struct { avgDialTime int64 // kind of average time of dial time - conns map[string][]*persistConn // Buckets for udp, tcp and tcp-tls. + conns map[string][]*persistConn // Buckets for udp, tcp and tcp-tls. expire time.Duration // After this duration a connection is expired. addr string tlsConfig *tls.Config @@ -29,8 +29,8 @@ type transport struct { stop chan bool } -func newTransport(addr string) *transport { - t := &transport{ +func newTransport(addr string) *Transport { + t := &Transport{ avgDialTime: int64(defaultDialTimeout / 2), conns: make(map[string][]*persistConn), expire: defaultExpire, @@ -45,7 +45,7 @@ func newTransport(addr string) *transport { // len returns the number of connection, used for metrics. Can only be safely // used inside connManager() because of data races. -func (t *transport) len() int { +func (t *Transport) len() int { l := 0 for _, conns := range t.conns { l += len(conns) @@ -54,7 +54,7 @@ func (t *transport) len() int { } // connManagers manages the persistent connection cache for UDP and TCP. -func (t *transport) connManager() { +func (t *Transport) connManager() { ticker := time.NewTicker(t.expire) Wait: for { @@ -115,7 +115,7 @@ func closeConns(conns []*persistConn) { } // cleanup removes connections from cache. -func (t *transport) cleanup(all bool) { +func (t *Transport) cleanup(all bool) { staleTime := time.Now().Add(-t.expire) for proto, stack := range t.conns { if len(stack) == 0 { @@ -144,19 +144,19 @@ func (t *transport) cleanup(all bool) { } // Yield return the connection to transport for reuse. -func (t *transport) Yield(c *dns.Conn) { t.yield <- c } +func (t *Transport) Yield(c *dns.Conn) { t.yield <- c } // Start starts the transport's connection manager. -func (t *transport) Start() { go t.connManager() } +func (t *Transport) Start() { go t.connManager() } // Stop stops the transport's connection manager. -func (t *transport) Stop() { close(t.stop) } +func (t *Transport) Stop() { close(t.stop) } // SetExpire sets the connection expire time in transport. -func (t *transport) SetExpire(expire time.Duration) { t.expire = expire } +func (t *Transport) SetExpire(expire time.Duration) { t.expire = expire } // SetTLSConfig sets the TLS config in transport. -func (t *transport) SetTLSConfig(cfg *tls.Config) { t.tlsConfig = cfg } +func (t *Transport) SetTLSConfig(cfg *tls.Config) { t.tlsConfig = cfg } const ( defaultExpire = 10 * time.Second diff --git a/plugin/forward/protocol.go b/plugin/forward/protocol.go deleted file mode 100644 index 338b60116..000000000 --- a/plugin/forward/protocol.go +++ /dev/null @@ -1,30 +0,0 @@ -package forward - -// Copied from coredns/core/dnsserver/address.go - -import ( - "strings" -) - -// protocol returns the protocol of the string s. The second string returns s -// with the prefix chopped off. -func protocol(s string) (int, string) { - switch { - case strings.HasPrefix(s, _tls+"://"): - return TLS, s[len(_tls)+3:] - case strings.HasPrefix(s, _dns+"://"): - return DNS, s[len(_dns)+3:] - } - return DNS, s -} - -// Supported protocols. -const ( - DNS = iota + 1 - TLS -) - -const ( - _dns = "dns" - _tls = "tls" -) diff --git a/plugin/forward/proxy.go b/plugin/forward/proxy.go index ac74bf0f8..453dd015b 100644 --- a/plugin/forward/proxy.go +++ b/plugin/forward/proxy.go @@ -18,7 +18,7 @@ type Proxy struct { // Connection caching expire time.Duration - transport *transport + transport *Transport // health checking probe *up.Probe @@ -26,7 +26,7 @@ type Proxy struct { } // NewProxy returns a new proxy. -func NewProxy(addr string, protocol int) *Proxy { +func NewProxy(addr, trans string) *Proxy { p := &Proxy{ addr: addr, fails: 0, @@ -34,7 +34,7 @@ func NewProxy(addr string, protocol int) *Proxy { transport: newTransport(addr), avgRtt: int64(maxTimeout / 2), } - p.health = NewHealthChecker(protocol) + p.health = NewHealthChecker(trans) runtime.SetFinalizer(p, (*Proxy).finalizer) return p } diff --git a/plugin/forward/proxy_test.go b/plugin/forward/proxy_test.go index d7af25aa0..794103516 100644 --- a/plugin/forward/proxy_test.go +++ b/plugin/forward/proxy_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/pkg/transport" "github.com/coredns/coredns/plugin/test" "github.com/coredns/coredns/request" @@ -26,7 +27,7 @@ func TestProxyClose(t *testing.T) { ctx := context.TODO() for i := 0; i < 100; i++ { - p := NewProxy(s.Addr, DNS) + p := NewProxy(s.Addr, transport.DNS) p.start(hcInterval) go func() { p.Connect(ctx, state, options{}) }() @@ -95,7 +96,7 @@ func TestProxyTLSFail(t *testing.T) { } func TestProtocolSelection(t *testing.T) { - p := NewProxy("bad_address", DNS) + p := NewProxy("bad_address", transport.DNS) stateUDP := request.Request{W: &test.ResponseWriter{}, Req: new(dns.Msg)} stateTCP := request.Request{W: &test.ResponseWriter{TCP: true}, Req: new(dns.Msg)} diff --git a/plugin/forward/setup.go b/plugin/forward/setup.go index 9fe8a6c38..6179b0d2d 100644 --- a/plugin/forward/setup.go +++ b/plugin/forward/setup.go @@ -2,7 +2,6 @@ package forward import ( "fmt" - "net" "strconv" "time" @@ -11,6 +10,7 @@ import ( "github.com/coredns/coredns/plugin/metrics" "github.com/coredns/coredns/plugin/pkg/dnsutil" pkgtls "github.com/coredns/coredns/plugin/pkg/tls" + "github.com/coredns/coredns/plugin/pkg/transport" "github.com/mholt/caddy" "github.com/mholt/caddy/caddyfile" @@ -93,8 +93,6 @@ func parseForward(c *caddy.Controller) (*Forward, error) { func ParseForwardStanza(c *caddyfile.Dispenser) (*Forward, error) { f := New() - protocols := map[int]int{} - if !c.Args(&f.from) { return f, c.ArgErr() } @@ -105,41 +103,17 @@ func ParseForwardStanza(c *caddyfile.Dispenser) (*Forward, error) { return f, c.ArgErr() } - // A bit fiddly, but first check if we've got protocols and if so add them back in when we create the proxies. - protocols = make(map[int]int) - for i := range to { - protocols[i], to[i] = protocol(to[i]) - } - - // If parseHostPortOrFile expands a file with a lot of nameserver our accounting in protocols doesn't make - // any sense anymore... For now: lets don't care. toHosts, err := dnsutil.ParseHostPortOrFile(to...) if err != nil { return f, err } - for i, h := range toHosts { - // Double check the port, if e.g. is 53 and the transport is TLS make it 853. - // This can be somewhat annoying because you *can't* have TLS on port 53 then. - switch protocols[i] { - case TLS: - h1, p, err := net.SplitHostPort(h) - if err != nil { - break - } - - // This is more of a bug in dnsutil.ParseHostPortOrFile that defaults to - // 53 because it doesn't know about the tls:// // and friends (that should be fixed). Hence - // Fix the port number here, back to what the user intended. - if p == "53" { - h = net.JoinHostPort(h1, "853") - } - } - - // We can't set tlsConfig here, because we haven't parsed it yet. - // We set it below at the end of parseBlock, use nil now. - p := NewProxy(h, protocols[i]) + transports := make([]string, len(toHosts)) + for i, host := range toHosts { + trans, h := transport.Parse(host) + p := NewProxy(h, trans) f.proxies = append(f.proxies, p) + transports[i] = trans } for c.NextBlock() { @@ -153,7 +127,7 @@ func ParseForwardStanza(c *caddyfile.Dispenser) (*Forward, error) { } for i := range f.proxies { // Only set this for proxies that need it. - if protocols[i] == TLS { + if transports[i] == transport.TLS { f.proxies[i].SetTLSConfig(f.tlsConfig) } f.proxies[i].SetExpire(f.expire) diff --git a/plugin/forward/truncated_test.go b/plugin/forward/truncated_test.go index b7ff47c14..40fc8185f 100644 --- a/plugin/forward/truncated_test.go +++ b/plugin/forward/truncated_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/pkg/transport" "github.com/coredns/coredns/plugin/test" "github.com/coredns/coredns/request" @@ -34,7 +35,7 @@ func TestLookupTruncated(t *testing.T) { }) defer s.Close() - p := NewProxy(s.Addr, DNS) + p := NewProxy(s.Addr, transport.DNS) f := New() f.SetProxy(p) defer f.Close() @@ -88,9 +89,9 @@ func TestForwardTruncated(t *testing.T) { f := New() - p1 := NewProxy(s.Addr, DNS) + p1 := NewProxy(s.Addr, transport.DNS) f.SetProxy(p1) - p2 := NewProxy(s.Addr, DNS) + p2 := NewProxy(s.Addr, transport.DNS) f.SetProxy(p2) defer f.Close() diff --git a/plugin/kubernetes/README.md b/plugin/kubernetes/README.md index e85ce1d43..cb51ae458 100644 --- a/plugin/kubernetes/README.md +++ b/plugin/kubernetes/README.md @@ -95,7 +95,7 @@ kubernetes [ZONES...] { (only `to` is allow). **ADDRESS** must be denoted in CIDR notation (127.0.0.1/32 etc.) or just as plain addresses. The special wildcard `*` means: the entire internet. Sending DNS notifies is not supported. - [Deprecated](https://github.com/kubernetes/dns/blob/master/docs/specification.md#26---deprecated-records) pod records in the sub domain `pod.cluster.local` are not transferred. + [Deprecated](https://github.com/kubernetes/dns/blob/master/docs/specification.md#26---deprecated-records) pod records in the sub domain `pod.cluster.local` are not transferred. * `fallthrough` **[ZONES...]** If a query for a record in the zones for which the plugin is authoritative results in NXDOMAIN, normally that is what the response will be. However, if you specify this option, the query will instead be passed on down the plugin chain, which can include another plugin to handle diff --git a/plugin/normalize.go b/plugin/normalize.go index e44e55385..e38d5fa08 100644 --- a/plugin/normalize.go +++ b/plugin/normalize.go @@ -6,6 +6,8 @@ import ( "strconv" "strings" + "github.com/coredns/coredns/plugin/pkg/transport" + "github.com/miekg/dns" ) @@ -61,22 +63,10 @@ type ( // Normalize will return the host portion of host, stripping // of any port or transport. The host will also be fully qualified and lowercased. func (h Host) Normalize() string { - s := string(h) + _, s = transport.Parse(s) - switch { - case strings.HasPrefix(s, TransportTLS+"://"): - s = s[len(TransportTLS+"://"):] - case strings.HasPrefix(s, TransportDNS+"://"): - s = s[len(TransportDNS+"://"):] - case strings.HasPrefix(s, TransportGRPC+"://"): - s = s[len(TransportGRPC+"://"):] - case strings.HasPrefix(s, TransportHTTPS+"://"): - s = s[len(TransportHTTPS+"://"):] - } - - // The error can be ignore here, because this function is called after the corefile - // has already been vetted. + // The error can be ignore here, because this function is called after the corefile has already been vetted. host, _, _, _ := SplitHostPort(s) return Name(host).Normalize() } @@ -138,11 +128,3 @@ func SplitHostPort(s string) (host, port string, ipnet *net.IPNet, err error) { } return host, port, n, nil } - -// Duplicated from core/dnsserver/address.go ! -const ( - TransportDNS = "dns" - TransportTLS = "tls" - TransportGRPC = "grpc" - TransportHTTPS = "https" -) diff --git a/plugin/pkg/dnsutil/host.go b/plugin/pkg/dnsutil/host.go index aaab586e8..b03b39586 100644 --- a/plugin/pkg/dnsutil/host.go +++ b/plugin/pkg/dnsutil/host.go @@ -5,15 +5,21 @@ import ( "net" "os" + "github.com/coredns/coredns/plugin/pkg/transport" + "github.com/miekg/dns" ) -// ParseHostPortOrFile parses the strings in s, each string can either be a address, -// address:port or a filename. The address part is checked and the filename case a -// resolv.conf like file is parsed and the nameserver found are returned. +// ParseHostPortOrFile parses the strings in s, each string can either be a +// address, [scheme://]address:port or a filename. The address part is checked +// and in case of filename a resolv.conf like file is (assumed) and parsed and +// the nameservers found are returned. func ParseHostPortOrFile(s ...string) ([]string, error) { var servers []string - for _, host := range s { + for _, h := range s { + + trans, host := transport.Parse(h) + addr, _, err := net.SplitHostPort(host) if err != nil { // Parse didn't work, it is not a addr:port combo @@ -26,13 +32,23 @@ func ParseHostPortOrFile(s ...string) ([]string, error) { } return servers, fmt.Errorf("not an IP address or file: %q", host) } - ss := net.JoinHostPort(host, "53") + var ss string + switch trans { + case transport.DNS: + ss = net.JoinHostPort(host, "53") + case transport.TLS: + ss = transport.TLS + "://" + net.JoinHostPort(host, transport.TLSPort) + case transport.GRPC: + ss = transport.GRPC + "://" + net.JoinHostPort(host, transport.GRPCPort) + case transport.HTTPS: + ss = transport.HTTPS + "://" + net.JoinHostPort(host, transport.HTTPSPort) + } servers = append(servers, ss) continue } if net.ParseIP(addr) == nil { - // No an IP address. + // Not an IP address. ss, err := tryFile(host) if err == nil { servers = append(servers, ss...) @@ -40,7 +56,7 @@ func ParseHostPortOrFile(s ...string) ([]string, error) { } return servers, fmt.Errorf("not an IP address or file: %q", host) } - servers = append(servers, host) + servers = append(servers, h) } return servers, nil } diff --git a/plugin/pkg/transport/transport.go b/plugin/pkg/transport/transport.go new file mode 100644 index 000000000..690b7768c --- /dev/null +++ b/plugin/pkg/transport/transport.go @@ -0,0 +1,49 @@ +package transport + +import ( + "strings" +) + +// Parse returns the transport defined in s and a string where the +// transport prefix is removed (if there was any). If no transport is defined +// we default to TransportDNS +func Parse(s string) (transport string, addr string) { + switch { + case strings.HasPrefix(s, TLS+"://"): + s = s[len(TLS+"://"):] + return TLS, s + + case strings.HasPrefix(s, DNS+"://"): + s = s[len(DNS+"://"):] + return DNS, s + + case strings.HasPrefix(s, GRPC+"://"): + s = s[len(GRPC+"://"):] + return GRPC, s + + case strings.HasPrefix(s, HTTPS+"://"): + s = s[len(HTTPS+"://"):] + + return HTTPS, s + } + + return DNS, s +} + +// Supported transports. +const ( + DNS = "dns" + TLS = "tls" + GRPC = "grpc" + HTTPS = "https" +) + +// Port numbers for the various protocols +const ( + // TLSPort is the default port for DNS-over-TLS. + TLSPort = "853" + // GRPCPort is the default port for DNS-over-gRPC. + GRPCPort = "443" + // HTTPSPort is the default port for DNS-over-HTTPS. + HTTPSPort = "443" +) diff --git a/plugin/pkg/transport/transport_test.go b/plugin/pkg/transport/transport_test.go new file mode 100644 index 000000000..5f93266eb --- /dev/null +++ b/plugin/pkg/transport/transport_test.go @@ -0,0 +1,21 @@ +package transport + +import "testing" + +func TestParse(t *testing.T) { + for i, test := range []struct { + input string + expected string + }{ + {"dns://.:53", DNS}, + {"2003::1/64.:53", DNS}, + {"grpc://example.org:1443 ", GRPC}, + {"tls://example.org ", TLS}, + {"https://example.org ", HTTPS}, + } { + actual, _ := Parse(test.input) + if actual != test.expected { + t.Errorf("Test %d: Expected %s but got %s", i, test.expected, actual) + } + } +}