Cleanup ParseHostOrFile (#2100)

Create plugin/pkg/transport that holds the transport related functions.
This needed to be a new pkg to prevent cyclic import errors.

This cleans up a bunch of duplicated code in core/dnsserver that also
tried to parse a transport (now all done in transport.Parse).

Signed-off-by: Miek Gieben <miek@miek.nl>
This commit is contained in:
Miek Gieben 2018-09-19 07:29:37 +01:00 committed by GitHub
parent 2f1223c36a
commit c349446a23
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
24 changed files with 182 additions and 221 deletions

View file

@ -6,6 +6,7 @@ import (
"strings" "strings"
"github.com/coredns/coredns/plugin" "github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
@ -27,43 +28,13 @@ func (z zoneAddr) String() string {
return s return s
} }
// Transport returns the protocol of the string s
func Transport(s string) string {
switch {
case strings.HasPrefix(s, TransportTLS+"://"):
return TransportTLS
case strings.HasPrefix(s, TransportDNS+"://"):
return TransportDNS
case strings.HasPrefix(s, TransportGRPC+"://"):
return TransportGRPC
case strings.HasPrefix(s, TransportHTTPS+"://"):
return TransportHTTPS
}
return TransportDNS
}
// normalizeZone parses an zone string into a structured format with separate // normalizeZone parses an zone string into a structured format with separate
// host, and port portions, as well as the original input string. // host, and port portions, as well as the original input string.
func normalizeZone(str string) (zoneAddr, error) { func normalizeZone(str string) (zoneAddr, error) {
var err error var err error
// Default to DNS if there isn't a transport protocol prefix. var trans string
trans := TransportDNS trans, str = transport.Parse(str)
switch {
case strings.HasPrefix(str, TransportTLS+"://"):
trans = TransportTLS
str = str[len(TransportTLS+"://"):]
case strings.HasPrefix(str, TransportDNS+"://"):
trans = TransportDNS
str = str[len(TransportDNS+"://"):]
case strings.HasPrefix(str, TransportGRPC+"://"):
trans = TransportGRPC
str = str[len(TransportGRPC+"://"):]
case strings.HasPrefix(str, TransportHTTPS+"://"):
trans = TransportHTTPS
str = str[len(TransportHTTPS+"://"):]
}
host, port, ipnet, err := plugin.SplitHostPort(str) host, port, ipnet, err := plugin.SplitHostPort(str)
if err != nil { if err != nil {
@ -71,17 +42,15 @@ func normalizeZone(str string) (zoneAddr, error) {
} }
if port == "" { if port == "" {
if trans == TransportDNS { switch trans {
case transport.DNS:
port = Port port = Port
} case transport.TLS:
if trans == TransportTLS { port = transport.TLSPort
port = TLSPort case transport.GRPC:
} port = transport.GRPCPort
if trans == TransportGRPC { case transport.HTTPS:
port = GRPCPort port = transport.HTTPSPort
}
if trans == TransportHTTPS {
port = HTTPSPort
} }
} }
@ -103,14 +72,6 @@ func SplitProtocolHostPort(address string) (protocol string, ip string, port str
} }
} }
// Supported transports.
const (
TransportDNS = "dns"
TransportTLS = "tls"
TransportGRPC = "grpc"
TransportHTTPS = "https"
)
type zoneOverlap struct { type zoneOverlap struct {
registeredAddr map[zoneAddr]zoneAddr // each zoneAddr is registered once by its key registeredAddr map[zoneAddr]zoneAddr // each zoneAddr is registered once by its key
unboundOverlap map[zoneAddr]zoneAddr // the "no bind" equiv ZoneAdddr is registered by its original key unboundOverlap map[zoneAddr]zoneAddr // the "no bind" equiv ZoneAdddr is registered by its original key

View file

@ -192,21 +192,3 @@ func TestOverlapAddressChecker(t *testing.T) {
} }
} }
} }
func TestTransport(t *testing.T) {
for i, test := range []struct {
input string
expected string
}{
{"dns://.:53", TransportDNS},
{"2003::1/64.:53", TransportDNS},
{"grpc://example.org:1443 ", TransportGRPC},
{"tls://example.org ", TransportTLS},
{"https://example.org ", TransportHTTPS},
} {
actual := Transport(test.input)
if actual != test.expected {
t.Errorf("Test %d: Expected %s but got %s", i, test.expected, actual)
}
}
}

View file

@ -9,6 +9,7 @@ import (
"github.com/coredns/coredns/plugin" "github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/pkg/dnsutil" "github.com/coredns/coredns/plugin/pkg/dnsutil"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/mholt/caddy" "github.com/mholt/caddy"
"github.com/mholt/caddy/caddyfile" "github.com/mholt/caddy/caddyfile"
@ -111,29 +112,29 @@ func (h *dnsContext) MakeServers() ([]caddy.Server, error) {
var servers []caddy.Server var servers []caddy.Server
for addr, group := range groups { for addr, group := range groups {
// switch on addr // switch on addr
switch Transport(addr) { switch tr, _ := transport.Parse(addr); tr {
case TransportDNS: case transport.DNS:
s, err := NewServer(addr, group) s, err := NewServer(addr, group)
if err != nil { if err != nil {
return nil, err return nil, err
} }
servers = append(servers, s) servers = append(servers, s)
case TransportTLS: case transport.TLS:
s, err := NewServerTLS(addr, group) s, err := NewServerTLS(addr, group)
if err != nil { if err != nil {
return nil, err return nil, err
} }
servers = append(servers, s) servers = append(servers, s)
case TransportGRPC: case transport.GRPC:
s, err := NewServergRPC(addr, group) s, err := NewServergRPC(addr, group)
if err != nil { if err != nil {
return nil, err return nil, err
} }
servers = append(servers, s) servers = append(servers, s)
case TransportHTTPS: case transport.HTTPS:
s, err := NewServerHTTPS(addr, group) s, err := NewServerHTTPS(addr, group)
if err != nil { if err != nil {
return nil, err return nil, err
@ -234,16 +235,8 @@ func groupConfigsByListenAddr(configs []*Config) (map[string][]*Config, error) {
return groups, nil return groups, nil
} }
const ( // DefaultPort is the default port.
// DefaultPort is the default port. const DefaultPort = "53"
DefaultPort = "53"
// TLSPort is the default port for DNS-over-TLS.
TLSPort = "853"
// GRPCPort is the default port for DNS-over-gRPC.
GRPCPort = "443"
// HTTPSPort is the default port for DNS-over-HTTPS.
HTTPSPort = "443"
)
// These "soft defaults" are configurable by // These "soft defaults" are configurable by
// command line flags, etc. // command line flags, etc.

View file

@ -15,6 +15,7 @@ import (
"github.com/coredns/coredns/plugin/pkg/log" "github.com/coredns/coredns/plugin/pkg/log"
"github.com/coredns/coredns/plugin/pkg/rcode" "github.com/coredns/coredns/plugin/pkg/rcode"
"github.com/coredns/coredns/plugin/pkg/trace" "github.com/coredns/coredns/plugin/pkg/trace"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/coredns/coredns/request" "github.com/coredns/coredns/request"
"github.com/miekg/dns" "github.com/miekg/dns"
@ -134,7 +135,7 @@ func (s *Server) ServePacket(p net.PacketConn) error {
// Listen implements caddy.TCPServer interface. // Listen implements caddy.TCPServer interface.
func (s *Server) Listen() (net.Listener, error) { func (s *Server) Listen() (net.Listener, error) {
l, err := net.Listen("tcp", s.Addr[len(TransportDNS+"://"):]) l, err := net.Listen("tcp", s.Addr[len(transport.DNS+"://"):])
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -143,7 +144,7 @@ func (s *Server) Listen() (net.Listener, error) {
// ListenPacket implements caddy.UDPServer interface. // ListenPacket implements caddy.UDPServer interface.
func (s *Server) ListenPacket() (net.PacketConn, error) { func (s *Server) ListenPacket() (net.PacketConn, error) {
p, err := net.ListenPacket("udp", s.Addr[len(TransportDNS+"://"):]) p, err := net.ListenPacket("udp", s.Addr[len(transport.DNS+"://"):])
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -8,6 +8,7 @@ import (
"net" "net"
"github.com/coredns/coredns/pb" "github.com/coredns/coredns/pb"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/coredns/coredns/plugin/pkg/watch" "github.com/coredns/coredns/plugin/pkg/watch"
"github.com/grpc-ecosystem/grpc-opentracing/go/otgrpc" "github.com/grpc-ecosystem/grpc-opentracing/go/otgrpc"
@ -73,7 +74,7 @@ func (s *ServergRPC) ServePacket(p net.PacketConn) error { return nil }
// Listen implements caddy.TCPServer interface. // Listen implements caddy.TCPServer interface.
func (s *ServergRPC) Listen() (net.Listener, error) { func (s *ServergRPC) Listen() (net.Listener, error) {
l, err := net.Listen("tcp", s.Addr[len(TransportGRPC+"://"):]) l, err := net.Listen("tcp", s.Addr[len(transport.GRPC+"://"):])
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -90,7 +91,7 @@ func (s *ServergRPC) OnStartupComplete() {
return return
} }
out := startUpZones(TransportGRPC+"://", s.Addr, s.zones) out := startUpZones(transport.GRPC+"://", s.Addr, s.zones)
if out != "" { if out != "" {
fmt.Print(out) fmt.Print(out)
} }

View file

@ -12,6 +12,7 @@ import (
"github.com/coredns/coredns/plugin/pkg/dnsutil" "github.com/coredns/coredns/plugin/pkg/dnsutil"
"github.com/coredns/coredns/plugin/pkg/doh" "github.com/coredns/coredns/plugin/pkg/doh"
"github.com/coredns/coredns/plugin/pkg/response" "github.com/coredns/coredns/plugin/pkg/response"
"github.com/coredns/coredns/plugin/pkg/transport"
) )
// ServerHTTPS represents an instance of a DNS-over-HTTPS server. // ServerHTTPS represents an instance of a DNS-over-HTTPS server.
@ -60,7 +61,7 @@ func (s *ServerHTTPS) ServePacket(p net.PacketConn) error { return nil }
// Listen implements caddy.TCPServer interface. // Listen implements caddy.TCPServer interface.
func (s *ServerHTTPS) Listen() (net.Listener, error) { func (s *ServerHTTPS) Listen() (net.Listener, error) {
l, err := net.Listen("tcp", s.Addr[len(TransportHTTPS+"://"):]) l, err := net.Listen("tcp", s.Addr[len(transport.HTTPS+"://"):])
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -77,7 +78,7 @@ func (s *ServerHTTPS) OnStartupComplete() {
return return
} }
out := startUpZones(TransportHTTPS+"://", s.Addr, s.zones) out := startUpZones(transport.HTTPS+"://", s.Addr, s.zones)
if out != "" { if out != "" {
fmt.Print(out) fmt.Print(out)
} }

View file

@ -6,6 +6,8 @@ import (
"fmt" "fmt"
"net" "net"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
@ -55,7 +57,7 @@ func (s *ServerTLS) ServePacket(p net.PacketConn) error { return nil }
// Listen implements caddy.TCPServer interface. // Listen implements caddy.TCPServer interface.
func (s *ServerTLS) Listen() (net.Listener, error) { func (s *ServerTLS) Listen() (net.Listener, error) {
l, err := net.Listen("tcp", s.Addr[len(TransportTLS+"://"):]) l, err := net.Listen("tcp", s.Addr[len(transport.TLS+"://"):])
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -72,7 +74,7 @@ func (s *ServerTLS) OnStartupComplete() {
return return
} }
out := startUpZones(TransportTLS+"://", s.Addr, s.zones) out := startUpZones(transport.TLS+"://", s.Addr, s.zones)
if out != "" { if out != "" {
fmt.Print(out) fmt.Print(out)
} }

View file

@ -35,16 +35,16 @@ func averageTimeout(currentAvg *int64, observedDuration time.Duration, weight in
atomic.AddInt64(currentAvg, int64(observedDuration-dt)/weight) atomic.AddInt64(currentAvg, int64(observedDuration-dt)/weight)
} }
func (t *transport) dialTimeout() time.Duration { func (t *Transport) dialTimeout() time.Duration {
return limitTimeout(&t.avgDialTime, minDialTimeout, maxDialTimeout) return limitTimeout(&t.avgDialTime, minDialTimeout, maxDialTimeout)
} }
func (t *transport) updateDialTimeout(newDialTime time.Duration) { func (t *Transport) updateDialTimeout(newDialTime time.Duration) {
averageTimeout(&t.avgDialTime, newDialTime, cumulativeAvgWeight) averageTimeout(&t.avgDialTime, newDialTime, cumulativeAvgWeight)
} }
// Dial dials the address configured in transport, potentially reusing a connection or creating a new one. // Dial dials the address configured in transport, potentially reusing a connection or creating a new one.
func (t *transport) Dial(proto string) (*dns.Conn, bool, error) { func (t *Transport) Dial(proto string) (*dns.Conn, bool, error) {
// If tls has been configured; use it. // If tls has been configured; use it.
if t.tlsConfig != nil { if t.tlsConfig != nil {
proto = "tcp-tls" proto = "tcp-tls"

View file

@ -4,6 +4,7 @@ import (
"testing" "testing"
"github.com/coredns/coredns/plugin/pkg/dnstest" "github.com/coredns/coredns/plugin/pkg/dnstest"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/coredns/coredns/plugin/test" "github.com/coredns/coredns/plugin/test"
"github.com/coredns/coredns/request" "github.com/coredns/coredns/request"
@ -19,7 +20,7 @@ func TestForward(t *testing.T) {
}) })
defer s.Close() defer s.Close()
p := NewProxy(s.Addr, DNS) p := NewProxy(s.Addr, transport.DNS)
f := New() f := New()
f.SetProxy(p) f.SetProxy(p)
defer f.Close() defer f.Close()
@ -51,7 +52,7 @@ func TestForwardRefused(t *testing.T) {
}) })
defer s.Close() defer s.Close()
p := NewProxy(s.Addr, DNS) p := NewProxy(s.Addr, transport.DNS)
f := New() f := New()
f.SetProxy(p) f.SetProxy(p)
defer f.Close() defer f.Close()

View file

@ -5,6 +5,8 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
@ -17,10 +19,10 @@ type HealthChecker interface {
// dnsHc is a health checker for a DNS endpoint (DNS, and DoT). // dnsHc is a health checker for a DNS endpoint (DNS, and DoT).
type dnsHc struct{ c *dns.Client } type dnsHc struct{ c *dns.Client }
// NewHealthChecker returns a new HealthChecker based on protocol. // NewHealthChecker returns a new HealthChecker based on transport.
func NewHealthChecker(protocol int) HealthChecker { func NewHealthChecker(trans string) HealthChecker {
switch protocol { switch trans {
case DNS, TLS: case transport.DNS, transport.TLS:
c := new(dns.Client) c := new(dns.Client)
c.Net = "udp" c.Net = "udp"
c.ReadTimeout = 1 * time.Second c.ReadTimeout = 1 * time.Second

View file

@ -7,6 +7,7 @@ import (
"time" "time"
"github.com/coredns/coredns/plugin/pkg/dnstest" "github.com/coredns/coredns/plugin/pkg/dnstest"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/coredns/coredns/plugin/test" "github.com/coredns/coredns/plugin/test"
"github.com/miekg/dns" "github.com/miekg/dns"
@ -25,7 +26,7 @@ func TestHealth(t *testing.T) {
}) })
defer s.Close() defer s.Close()
p := NewProxy(s.Addr, DNS) p := NewProxy(s.Addr, transport.DNS)
f := New() f := New()
f.SetProxy(p) f.SetProxy(p)
defer f.Close() defer f.Close()
@ -65,7 +66,7 @@ func TestHealthTimeout(t *testing.T) {
}) })
defer s.Close() defer s.Close()
p := NewProxy(s.Addr, DNS) p := NewProxy(s.Addr, transport.DNS)
f := New() f := New()
f.SetProxy(p) f.SetProxy(p)
defer f.Close() defer f.Close()
@ -109,7 +110,7 @@ func TestHealthFailTwice(t *testing.T) {
}) })
defer s.Close() defer s.Close()
p := NewProxy(s.Addr, DNS) p := NewProxy(s.Addr, transport.DNS)
f := New() f := New()
f.SetProxy(p) f.SetProxy(p)
defer f.Close() defer f.Close()
@ -132,7 +133,7 @@ func TestHealthMaxFails(t *testing.T) {
}) })
defer s.Close() defer s.Close()
p := NewProxy(s.Addr, DNS) p := NewProxy(s.Addr, transport.DNS)
f := New() f := New()
f.maxfails = 2 f.maxfails = 2
f.SetProxy(p) f.SetProxy(p)
@ -163,7 +164,7 @@ func TestHealthNoMaxFails(t *testing.T) {
}) })
defer s.Close() defer s.Close()
p := NewProxy(s.Addr, DNS) p := NewProxy(s.Addr, transport.DNS)
f := New() f := New()
f.maxfails = 0 f.maxfails = 0
f.SetProxy(p) f.SetProxy(p)

View file

@ -7,6 +7,7 @@ package forward
import ( import (
"context" "context"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/coredns/coredns/request" "github.com/coredns/coredns/request"
"github.com/miekg/dns" "github.com/miekg/dns"
@ -81,7 +82,7 @@ func (f *Forward) Lookup(state request.Request, name string, typ uint16) (*dns.M
func NewLookup(addr []string) *Forward { func NewLookup(addr []string) *Forward {
f := New() f := New()
for i := range addr { for i := range addr {
p := NewProxy(addr[i], DNS) p := NewProxy(addr[i], transport.DNS)
f.SetProxy(p) f.SetProxy(p)
} }
return f return f

View file

@ -4,6 +4,7 @@ import (
"testing" "testing"
"github.com/coredns/coredns/plugin/pkg/dnstest" "github.com/coredns/coredns/plugin/pkg/dnstest"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/coredns/coredns/plugin/test" "github.com/coredns/coredns/plugin/test"
"github.com/coredns/coredns/request" "github.com/coredns/coredns/request"
@ -19,7 +20,7 @@ func TestLookup(t *testing.T) {
}) })
defer s.Close() defer s.Close()
p := NewProxy(s.Addr, DNS) p := NewProxy(s.Addr, transport.DNS)
f := New() f := New()
f.SetProxy(p) f.SetProxy(p)
defer f.Close() defer f.Close()

View file

@ -15,10 +15,10 @@ type persistConn struct {
used time.Time used time.Time
} }
// transport hold the persistent cache. // Transport hold the persistent cache.
type transport struct { type Transport struct {
avgDialTime int64 // kind of average time of dial time avgDialTime int64 // kind of average time of dial time
conns map[string][]*persistConn // Buckets for udp, tcp and tcp-tls. conns map[string][]*persistConn // Buckets for udp, tcp and tcp-tls.
expire time.Duration // After this duration a connection is expired. expire time.Duration // After this duration a connection is expired.
addr string addr string
tlsConfig *tls.Config tlsConfig *tls.Config
@ -29,8 +29,8 @@ type transport struct {
stop chan bool stop chan bool
} }
func newTransport(addr string) *transport { func newTransport(addr string) *Transport {
t := &transport{ t := &Transport{
avgDialTime: int64(defaultDialTimeout / 2), avgDialTime: int64(defaultDialTimeout / 2),
conns: make(map[string][]*persistConn), conns: make(map[string][]*persistConn),
expire: defaultExpire, expire: defaultExpire,
@ -45,7 +45,7 @@ func newTransport(addr string) *transport {
// len returns the number of connection, used for metrics. Can only be safely // len returns the number of connection, used for metrics. Can only be safely
// used inside connManager() because of data races. // used inside connManager() because of data races.
func (t *transport) len() int { func (t *Transport) len() int {
l := 0 l := 0
for _, conns := range t.conns { for _, conns := range t.conns {
l += len(conns) l += len(conns)
@ -54,7 +54,7 @@ func (t *transport) len() int {
} }
// connManagers manages the persistent connection cache for UDP and TCP. // connManagers manages the persistent connection cache for UDP and TCP.
func (t *transport) connManager() { func (t *Transport) connManager() {
ticker := time.NewTicker(t.expire) ticker := time.NewTicker(t.expire)
Wait: Wait:
for { for {
@ -115,7 +115,7 @@ func closeConns(conns []*persistConn) {
} }
// cleanup removes connections from cache. // cleanup removes connections from cache.
func (t *transport) cleanup(all bool) { func (t *Transport) cleanup(all bool) {
staleTime := time.Now().Add(-t.expire) staleTime := time.Now().Add(-t.expire)
for proto, stack := range t.conns { for proto, stack := range t.conns {
if len(stack) == 0 { if len(stack) == 0 {
@ -144,19 +144,19 @@ func (t *transport) cleanup(all bool) {
} }
// Yield return the connection to transport for reuse. // Yield return the connection to transport for reuse.
func (t *transport) Yield(c *dns.Conn) { t.yield <- c } func (t *Transport) Yield(c *dns.Conn) { t.yield <- c }
// Start starts the transport's connection manager. // Start starts the transport's connection manager.
func (t *transport) Start() { go t.connManager() } func (t *Transport) Start() { go t.connManager() }
// Stop stops the transport's connection manager. // Stop stops the transport's connection manager.
func (t *transport) Stop() { close(t.stop) } func (t *Transport) Stop() { close(t.stop) }
// SetExpire sets the connection expire time in transport. // SetExpire sets the connection expire time in transport.
func (t *transport) SetExpire(expire time.Duration) { t.expire = expire } func (t *Transport) SetExpire(expire time.Duration) { t.expire = expire }
// SetTLSConfig sets the TLS config in transport. // SetTLSConfig sets the TLS config in transport.
func (t *transport) SetTLSConfig(cfg *tls.Config) { t.tlsConfig = cfg } func (t *Transport) SetTLSConfig(cfg *tls.Config) { t.tlsConfig = cfg }
const ( const (
defaultExpire = 10 * time.Second defaultExpire = 10 * time.Second

View file

@ -1,30 +0,0 @@
package forward
// Copied from coredns/core/dnsserver/address.go
import (
"strings"
)
// protocol returns the protocol of the string s. The second string returns s
// with the prefix chopped off.
func protocol(s string) (int, string) {
switch {
case strings.HasPrefix(s, _tls+"://"):
return TLS, s[len(_tls)+3:]
case strings.HasPrefix(s, _dns+"://"):
return DNS, s[len(_dns)+3:]
}
return DNS, s
}
// Supported protocols.
const (
DNS = iota + 1
TLS
)
const (
_dns = "dns"
_tls = "tls"
)

View file

@ -18,7 +18,7 @@ type Proxy struct {
// Connection caching // Connection caching
expire time.Duration expire time.Duration
transport *transport transport *Transport
// health checking // health checking
probe *up.Probe probe *up.Probe
@ -26,7 +26,7 @@ type Proxy struct {
} }
// NewProxy returns a new proxy. // NewProxy returns a new proxy.
func NewProxy(addr string, protocol int) *Proxy { func NewProxy(addr, trans string) *Proxy {
p := &Proxy{ p := &Proxy{
addr: addr, addr: addr,
fails: 0, fails: 0,
@ -34,7 +34,7 @@ func NewProxy(addr string, protocol int) *Proxy {
transport: newTransport(addr), transport: newTransport(addr),
avgRtt: int64(maxTimeout / 2), avgRtt: int64(maxTimeout / 2),
} }
p.health = NewHealthChecker(protocol) p.health = NewHealthChecker(trans)
runtime.SetFinalizer(p, (*Proxy).finalizer) runtime.SetFinalizer(p, (*Proxy).finalizer)
return p return p
} }

View file

@ -5,6 +5,7 @@ import (
"testing" "testing"
"github.com/coredns/coredns/plugin/pkg/dnstest" "github.com/coredns/coredns/plugin/pkg/dnstest"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/coredns/coredns/plugin/test" "github.com/coredns/coredns/plugin/test"
"github.com/coredns/coredns/request" "github.com/coredns/coredns/request"
@ -26,7 +27,7 @@ func TestProxyClose(t *testing.T) {
ctx := context.TODO() ctx := context.TODO()
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
p := NewProxy(s.Addr, DNS) p := NewProxy(s.Addr, transport.DNS)
p.start(hcInterval) p.start(hcInterval)
go func() { p.Connect(ctx, state, options{}) }() go func() { p.Connect(ctx, state, options{}) }()
@ -95,7 +96,7 @@ func TestProxyTLSFail(t *testing.T) {
} }
func TestProtocolSelection(t *testing.T) { func TestProtocolSelection(t *testing.T) {
p := NewProxy("bad_address", DNS) p := NewProxy("bad_address", transport.DNS)
stateUDP := request.Request{W: &test.ResponseWriter{}, Req: new(dns.Msg)} stateUDP := request.Request{W: &test.ResponseWriter{}, Req: new(dns.Msg)}
stateTCP := request.Request{W: &test.ResponseWriter{TCP: true}, Req: new(dns.Msg)} stateTCP := request.Request{W: &test.ResponseWriter{TCP: true}, Req: new(dns.Msg)}

View file

@ -2,7 +2,6 @@ package forward
import ( import (
"fmt" "fmt"
"net"
"strconv" "strconv"
"time" "time"
@ -11,6 +10,7 @@ import (
"github.com/coredns/coredns/plugin/metrics" "github.com/coredns/coredns/plugin/metrics"
"github.com/coredns/coredns/plugin/pkg/dnsutil" "github.com/coredns/coredns/plugin/pkg/dnsutil"
pkgtls "github.com/coredns/coredns/plugin/pkg/tls" pkgtls "github.com/coredns/coredns/plugin/pkg/tls"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/mholt/caddy" "github.com/mholt/caddy"
"github.com/mholt/caddy/caddyfile" "github.com/mholt/caddy/caddyfile"
@ -93,8 +93,6 @@ func parseForward(c *caddy.Controller) (*Forward, error) {
func ParseForwardStanza(c *caddyfile.Dispenser) (*Forward, error) { func ParseForwardStanza(c *caddyfile.Dispenser) (*Forward, error) {
f := New() f := New()
protocols := map[int]int{}
if !c.Args(&f.from) { if !c.Args(&f.from) {
return f, c.ArgErr() return f, c.ArgErr()
} }
@ -105,41 +103,17 @@ func ParseForwardStanza(c *caddyfile.Dispenser) (*Forward, error) {
return f, c.ArgErr() return f, c.ArgErr()
} }
// A bit fiddly, but first check if we've got protocols and if so add them back in when we create the proxies.
protocols = make(map[int]int)
for i := range to {
protocols[i], to[i] = protocol(to[i])
}
// If parseHostPortOrFile expands a file with a lot of nameserver our accounting in protocols doesn't make
// any sense anymore... For now: lets don't care.
toHosts, err := dnsutil.ParseHostPortOrFile(to...) toHosts, err := dnsutil.ParseHostPortOrFile(to...)
if err != nil { if err != nil {
return f, err return f, err
} }
for i, h := range toHosts { transports := make([]string, len(toHosts))
// Double check the port, if e.g. is 53 and the transport is TLS make it 853. for i, host := range toHosts {
// This can be somewhat annoying because you *can't* have TLS on port 53 then. trans, h := transport.Parse(host)
switch protocols[i] { p := NewProxy(h, trans)
case TLS:
h1, p, err := net.SplitHostPort(h)
if err != nil {
break
}
// This is more of a bug in dnsutil.ParseHostPortOrFile that defaults to
// 53 because it doesn't know about the tls:// // and friends (that should be fixed). Hence
// Fix the port number here, back to what the user intended.
if p == "53" {
h = net.JoinHostPort(h1, "853")
}
}
// We can't set tlsConfig here, because we haven't parsed it yet.
// We set it below at the end of parseBlock, use nil now.
p := NewProxy(h, protocols[i])
f.proxies = append(f.proxies, p) f.proxies = append(f.proxies, p)
transports[i] = trans
} }
for c.NextBlock() { for c.NextBlock() {
@ -153,7 +127,7 @@ func ParseForwardStanza(c *caddyfile.Dispenser) (*Forward, error) {
} }
for i := range f.proxies { for i := range f.proxies {
// Only set this for proxies that need it. // Only set this for proxies that need it.
if protocols[i] == TLS { if transports[i] == transport.TLS {
f.proxies[i].SetTLSConfig(f.tlsConfig) f.proxies[i].SetTLSConfig(f.tlsConfig)
} }
f.proxies[i].SetExpire(f.expire) f.proxies[i].SetExpire(f.expire)

View file

@ -5,6 +5,7 @@ import (
"testing" "testing"
"github.com/coredns/coredns/plugin/pkg/dnstest" "github.com/coredns/coredns/plugin/pkg/dnstest"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/coredns/coredns/plugin/test" "github.com/coredns/coredns/plugin/test"
"github.com/coredns/coredns/request" "github.com/coredns/coredns/request"
@ -34,7 +35,7 @@ func TestLookupTruncated(t *testing.T) {
}) })
defer s.Close() defer s.Close()
p := NewProxy(s.Addr, DNS) p := NewProxy(s.Addr, transport.DNS)
f := New() f := New()
f.SetProxy(p) f.SetProxy(p)
defer f.Close() defer f.Close()
@ -88,9 +89,9 @@ func TestForwardTruncated(t *testing.T) {
f := New() f := New()
p1 := NewProxy(s.Addr, DNS) p1 := NewProxy(s.Addr, transport.DNS)
f.SetProxy(p1) f.SetProxy(p1)
p2 := NewProxy(s.Addr, DNS) p2 := NewProxy(s.Addr, transport.DNS)
f.SetProxy(p2) f.SetProxy(p2)
defer f.Close() defer f.Close()

View file

@ -95,7 +95,7 @@ kubernetes [ZONES...] {
(only `to` is allow). **ADDRESS** must be denoted in CIDR notation (127.0.0.1/32 etc.) or just as (only `to` is allow). **ADDRESS** must be denoted in CIDR notation (127.0.0.1/32 etc.) or just as
plain addresses. The special wildcard `*` means: the entire internet. plain addresses. The special wildcard `*` means: the entire internet.
Sending DNS notifies is not supported. Sending DNS notifies is not supported.
[Deprecated](https://github.com/kubernetes/dns/blob/master/docs/specification.md#26---deprecated-records) pod records in the sub domain `pod.cluster.local` are not transferred. [Deprecated](https://github.com/kubernetes/dns/blob/master/docs/specification.md#26---deprecated-records) pod records in the sub domain `pod.cluster.local` are not transferred.
* `fallthrough` **[ZONES...]** If a query for a record in the zones for which the plugin is authoritative * `fallthrough` **[ZONES...]** If a query for a record in the zones for which the plugin is authoritative
results in NXDOMAIN, normally that is what the response will be. However, if you specify this option, results in NXDOMAIN, normally that is what the response will be. However, if you specify this option,
the query will instead be passed on down the plugin chain, which can include another plugin to handle the query will instead be passed on down the plugin chain, which can include another plugin to handle

View file

@ -6,6 +6,8 @@ import (
"strconv" "strconv"
"strings" "strings"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
@ -61,22 +63,10 @@ type (
// Normalize will return the host portion of host, stripping // Normalize will return the host portion of host, stripping
// of any port or transport. The host will also be fully qualified and lowercased. // of any port or transport. The host will also be fully qualified and lowercased.
func (h Host) Normalize() string { func (h Host) Normalize() string {
s := string(h) s := string(h)
_, s = transport.Parse(s)
switch { // The error can be ignore here, because this function is called after the corefile has already been vetted.
case strings.HasPrefix(s, TransportTLS+"://"):
s = s[len(TransportTLS+"://"):]
case strings.HasPrefix(s, TransportDNS+"://"):
s = s[len(TransportDNS+"://"):]
case strings.HasPrefix(s, TransportGRPC+"://"):
s = s[len(TransportGRPC+"://"):]
case strings.HasPrefix(s, TransportHTTPS+"://"):
s = s[len(TransportHTTPS+"://"):]
}
// The error can be ignore here, because this function is called after the corefile
// has already been vetted.
host, _, _, _ := SplitHostPort(s) host, _, _, _ := SplitHostPort(s)
return Name(host).Normalize() return Name(host).Normalize()
} }
@ -138,11 +128,3 @@ func SplitHostPort(s string) (host, port string, ipnet *net.IPNet, err error) {
} }
return host, port, n, nil return host, port, n, nil
} }
// Duplicated from core/dnsserver/address.go !
const (
TransportDNS = "dns"
TransportTLS = "tls"
TransportGRPC = "grpc"
TransportHTTPS = "https"
)

View file

@ -5,15 +5,21 @@ import (
"net" "net"
"os" "os"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
// ParseHostPortOrFile parses the strings in s, each string can either be a address, // ParseHostPortOrFile parses the strings in s, each string can either be a
// address:port or a filename. The address part is checked and the filename case a // address, [scheme://]address:port or a filename. The address part is checked
// resolv.conf like file is parsed and the nameserver found are returned. // and in case of filename a resolv.conf like file is (assumed) and parsed and
// the nameservers found are returned.
func ParseHostPortOrFile(s ...string) ([]string, error) { func ParseHostPortOrFile(s ...string) ([]string, error) {
var servers []string var servers []string
for _, host := range s { for _, h := range s {
trans, host := transport.Parse(h)
addr, _, err := net.SplitHostPort(host) addr, _, err := net.SplitHostPort(host)
if err != nil { if err != nil {
// Parse didn't work, it is not a addr:port combo // Parse didn't work, it is not a addr:port combo
@ -26,13 +32,23 @@ func ParseHostPortOrFile(s ...string) ([]string, error) {
} }
return servers, fmt.Errorf("not an IP address or file: %q", host) return servers, fmt.Errorf("not an IP address or file: %q", host)
} }
ss := net.JoinHostPort(host, "53") var ss string
switch trans {
case transport.DNS:
ss = net.JoinHostPort(host, "53")
case transport.TLS:
ss = transport.TLS + "://" + net.JoinHostPort(host, transport.TLSPort)
case transport.GRPC:
ss = transport.GRPC + "://" + net.JoinHostPort(host, transport.GRPCPort)
case transport.HTTPS:
ss = transport.HTTPS + "://" + net.JoinHostPort(host, transport.HTTPSPort)
}
servers = append(servers, ss) servers = append(servers, ss)
continue continue
} }
if net.ParseIP(addr) == nil { if net.ParseIP(addr) == nil {
// No an IP address. // Not an IP address.
ss, err := tryFile(host) ss, err := tryFile(host)
if err == nil { if err == nil {
servers = append(servers, ss...) servers = append(servers, ss...)
@ -40,7 +56,7 @@ func ParseHostPortOrFile(s ...string) ([]string, error) {
} }
return servers, fmt.Errorf("not an IP address or file: %q", host) return servers, fmt.Errorf("not an IP address or file: %q", host)
} }
servers = append(servers, host) servers = append(servers, h)
} }
return servers, nil return servers, nil
} }

View file

@ -0,0 +1,49 @@
package transport
import (
"strings"
)
// Parse returns the transport defined in s and a string where the
// transport prefix is removed (if there was any). If no transport is defined
// we default to TransportDNS
func Parse(s string) (transport string, addr string) {
switch {
case strings.HasPrefix(s, TLS+"://"):
s = s[len(TLS+"://"):]
return TLS, s
case strings.HasPrefix(s, DNS+"://"):
s = s[len(DNS+"://"):]
return DNS, s
case strings.HasPrefix(s, GRPC+"://"):
s = s[len(GRPC+"://"):]
return GRPC, s
case strings.HasPrefix(s, HTTPS+"://"):
s = s[len(HTTPS+"://"):]
return HTTPS, s
}
return DNS, s
}
// Supported transports.
const (
DNS = "dns"
TLS = "tls"
GRPC = "grpc"
HTTPS = "https"
)
// Port numbers for the various protocols
const (
// TLSPort is the default port for DNS-over-TLS.
TLSPort = "853"
// GRPCPort is the default port for DNS-over-gRPC.
GRPCPort = "443"
// HTTPSPort is the default port for DNS-over-HTTPS.
HTTPSPort = "443"
)

View file

@ -0,0 +1,21 @@
package transport
import "testing"
func TestParse(t *testing.T) {
for i, test := range []struct {
input string
expected string
}{
{"dns://.:53", DNS},
{"2003::1/64.:53", DNS},
{"grpc://example.org:1443 ", GRPC},
{"tls://example.org ", TLS},
{"https://example.org ", HTTPS},
} {
actual, _ := Parse(test.input)
if actual != test.expected {
t.Errorf("Test %d: Expected %s but got %s", i, test.expected, actual)
}
}
}