Cleanup ParseHostOrFile (#2100)

Create plugin/pkg/transport that holds the transport related functions.
This needed to be a new pkg to prevent cyclic import errors.

This cleans up a bunch of duplicated code in core/dnsserver that also
tried to parse a transport (now all done in transport.Parse).

Signed-off-by: Miek Gieben <miek@miek.nl>
This commit is contained in:
Miek Gieben 2018-09-19 07:29:37 +01:00 committed by GitHub
parent 2f1223c36a
commit c349446a23
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
24 changed files with 182 additions and 221 deletions

View file

@ -6,6 +6,7 @@ import (
"strings" "strings"
"github.com/coredns/coredns/plugin" "github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
@ -27,43 +28,13 @@ func (z zoneAddr) String() string {
return s return s
} }
// Transport returns the protocol of the string s
func Transport(s string) string {
switch {
case strings.HasPrefix(s, TransportTLS+"://"):
return TransportTLS
case strings.HasPrefix(s, TransportDNS+"://"):
return TransportDNS
case strings.HasPrefix(s, TransportGRPC+"://"):
return TransportGRPC
case strings.HasPrefix(s, TransportHTTPS+"://"):
return TransportHTTPS
}
return TransportDNS
}
// normalizeZone parses an zone string into a structured format with separate // normalizeZone parses an zone string into a structured format with separate
// host, and port portions, as well as the original input string. // host, and port portions, as well as the original input string.
func normalizeZone(str string) (zoneAddr, error) { func normalizeZone(str string) (zoneAddr, error) {
var err error var err error
// Default to DNS if there isn't a transport protocol prefix. var trans string
trans := TransportDNS trans, str = transport.Parse(str)
switch {
case strings.HasPrefix(str, TransportTLS+"://"):
trans = TransportTLS
str = str[len(TransportTLS+"://"):]
case strings.HasPrefix(str, TransportDNS+"://"):
trans = TransportDNS
str = str[len(TransportDNS+"://"):]
case strings.HasPrefix(str, TransportGRPC+"://"):
trans = TransportGRPC
str = str[len(TransportGRPC+"://"):]
case strings.HasPrefix(str, TransportHTTPS+"://"):
trans = TransportHTTPS
str = str[len(TransportHTTPS+"://"):]
}
host, port, ipnet, err := plugin.SplitHostPort(str) host, port, ipnet, err := plugin.SplitHostPort(str)
if err != nil { if err != nil {
@ -71,17 +42,15 @@ func normalizeZone(str string) (zoneAddr, error) {
} }
if port == "" { if port == "" {
if trans == TransportDNS { switch trans {
case transport.DNS:
port = Port port = Port
} case transport.TLS:
if trans == TransportTLS { port = transport.TLSPort
port = TLSPort case transport.GRPC:
} port = transport.GRPCPort
if trans == TransportGRPC { case transport.HTTPS:
port = GRPCPort port = transport.HTTPSPort
}
if trans == TransportHTTPS {
port = HTTPSPort
} }
} }
@ -103,14 +72,6 @@ func SplitProtocolHostPort(address string) (protocol string, ip string, port str
} }
} }
// Supported transports.
const (
TransportDNS = "dns"
TransportTLS = "tls"
TransportGRPC = "grpc"
TransportHTTPS = "https"
)
type zoneOverlap struct { type zoneOverlap struct {
registeredAddr map[zoneAddr]zoneAddr // each zoneAddr is registered once by its key registeredAddr map[zoneAddr]zoneAddr // each zoneAddr is registered once by its key
unboundOverlap map[zoneAddr]zoneAddr // the "no bind" equiv ZoneAdddr is registered by its original key unboundOverlap map[zoneAddr]zoneAddr // the "no bind" equiv ZoneAdddr is registered by its original key

View file

@ -192,21 +192,3 @@ func TestOverlapAddressChecker(t *testing.T) {
} }
} }
} }
func TestTransport(t *testing.T) {
for i, test := range []struct {
input string
expected string
}{
{"dns://.:53", TransportDNS},
{"2003::1/64.:53", TransportDNS},
{"grpc://example.org:1443 ", TransportGRPC},
{"tls://example.org ", TransportTLS},
{"https://example.org ", TransportHTTPS},
} {
actual := Transport(test.input)
if actual != test.expected {
t.Errorf("Test %d: Expected %s but got %s", i, test.expected, actual)
}
}
}

View file

@ -9,6 +9,7 @@ import (
"github.com/coredns/coredns/plugin" "github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/pkg/dnsutil" "github.com/coredns/coredns/plugin/pkg/dnsutil"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/mholt/caddy" "github.com/mholt/caddy"
"github.com/mholt/caddy/caddyfile" "github.com/mholt/caddy/caddyfile"
@ -111,29 +112,29 @@ func (h *dnsContext) MakeServers() ([]caddy.Server, error) {
var servers []caddy.Server var servers []caddy.Server
for addr, group := range groups { for addr, group := range groups {
// switch on addr // switch on addr
switch Transport(addr) { switch tr, _ := transport.Parse(addr); tr {
case TransportDNS: case transport.DNS:
s, err := NewServer(addr, group) s, err := NewServer(addr, group)
if err != nil { if err != nil {
return nil, err return nil, err
} }
servers = append(servers, s) servers = append(servers, s)
case TransportTLS: case transport.TLS:
s, err := NewServerTLS(addr, group) s, err := NewServerTLS(addr, group)
if err != nil { if err != nil {
return nil, err return nil, err
} }
servers = append(servers, s) servers = append(servers, s)
case TransportGRPC: case transport.GRPC:
s, err := NewServergRPC(addr, group) s, err := NewServergRPC(addr, group)
if err != nil { if err != nil {
return nil, err return nil, err
} }
servers = append(servers, s) servers = append(servers, s)
case TransportHTTPS: case transport.HTTPS:
s, err := NewServerHTTPS(addr, group) s, err := NewServerHTTPS(addr, group)
if err != nil { if err != nil {
return nil, err return nil, err
@ -234,16 +235,8 @@ func groupConfigsByListenAddr(configs []*Config) (map[string][]*Config, error) {
return groups, nil return groups, nil
} }
const ( // DefaultPort is the default port.
// DefaultPort is the default port. const DefaultPort = "53"
DefaultPort = "53"
// TLSPort is the default port for DNS-over-TLS.
TLSPort = "853"
// GRPCPort is the default port for DNS-over-gRPC.
GRPCPort = "443"
// HTTPSPort is the default port for DNS-over-HTTPS.
HTTPSPort = "443"
)
// These "soft defaults" are configurable by // These "soft defaults" are configurable by
// command line flags, etc. // command line flags, etc.

View file

@ -15,6 +15,7 @@ import (
"github.com/coredns/coredns/plugin/pkg/log" "github.com/coredns/coredns/plugin/pkg/log"
"github.com/coredns/coredns/plugin/pkg/rcode" "github.com/coredns/coredns/plugin/pkg/rcode"
"github.com/coredns/coredns/plugin/pkg/trace" "github.com/coredns/coredns/plugin/pkg/trace"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/coredns/coredns/request" "github.com/coredns/coredns/request"
"github.com/miekg/dns" "github.com/miekg/dns"
@ -134,7 +135,7 @@ func (s *Server) ServePacket(p net.PacketConn) error {
// Listen implements caddy.TCPServer interface. // Listen implements caddy.TCPServer interface.
func (s *Server) Listen() (net.Listener, error) { func (s *Server) Listen() (net.Listener, error) {
l, err := net.Listen("tcp", s.Addr[len(TransportDNS+"://"):]) l, err := net.Listen("tcp", s.Addr[len(transport.DNS+"://"):])
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -143,7 +144,7 @@ func (s *Server) Listen() (net.Listener, error) {
// ListenPacket implements caddy.UDPServer interface. // ListenPacket implements caddy.UDPServer interface.
func (s *Server) ListenPacket() (net.PacketConn, error) { func (s *Server) ListenPacket() (net.PacketConn, error) {
p, err := net.ListenPacket("udp", s.Addr[len(TransportDNS+"://"):]) p, err := net.ListenPacket("udp", s.Addr[len(transport.DNS+"://"):])
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -8,6 +8,7 @@ import (
"net" "net"
"github.com/coredns/coredns/pb" "github.com/coredns/coredns/pb"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/coredns/coredns/plugin/pkg/watch" "github.com/coredns/coredns/plugin/pkg/watch"
"github.com/grpc-ecosystem/grpc-opentracing/go/otgrpc" "github.com/grpc-ecosystem/grpc-opentracing/go/otgrpc"
@ -73,7 +74,7 @@ func (s *ServergRPC) ServePacket(p net.PacketConn) error { return nil }
// Listen implements caddy.TCPServer interface. // Listen implements caddy.TCPServer interface.
func (s *ServergRPC) Listen() (net.Listener, error) { func (s *ServergRPC) Listen() (net.Listener, error) {
l, err := net.Listen("tcp", s.Addr[len(TransportGRPC+"://"):]) l, err := net.Listen("tcp", s.Addr[len(transport.GRPC+"://"):])
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -90,7 +91,7 @@ func (s *ServergRPC) OnStartupComplete() {
return return
} }
out := startUpZones(TransportGRPC+"://", s.Addr, s.zones) out := startUpZones(transport.GRPC+"://", s.Addr, s.zones)
if out != "" { if out != "" {
fmt.Print(out) fmt.Print(out)
} }

View file

@ -12,6 +12,7 @@ import (
"github.com/coredns/coredns/plugin/pkg/dnsutil" "github.com/coredns/coredns/plugin/pkg/dnsutil"
"github.com/coredns/coredns/plugin/pkg/doh" "github.com/coredns/coredns/plugin/pkg/doh"
"github.com/coredns/coredns/plugin/pkg/response" "github.com/coredns/coredns/plugin/pkg/response"
"github.com/coredns/coredns/plugin/pkg/transport"
) )
// ServerHTTPS represents an instance of a DNS-over-HTTPS server. // ServerHTTPS represents an instance of a DNS-over-HTTPS server.
@ -60,7 +61,7 @@ func (s *ServerHTTPS) ServePacket(p net.PacketConn) error { return nil }
// Listen implements caddy.TCPServer interface. // Listen implements caddy.TCPServer interface.
func (s *ServerHTTPS) Listen() (net.Listener, error) { func (s *ServerHTTPS) Listen() (net.Listener, error) {
l, err := net.Listen("tcp", s.Addr[len(TransportHTTPS+"://"):]) l, err := net.Listen("tcp", s.Addr[len(transport.HTTPS+"://"):])
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -77,7 +78,7 @@ func (s *ServerHTTPS) OnStartupComplete() {
return return
} }
out := startUpZones(TransportHTTPS+"://", s.Addr, s.zones) out := startUpZones(transport.HTTPS+"://", s.Addr, s.zones)
if out != "" { if out != "" {
fmt.Print(out) fmt.Print(out)
} }

View file

@ -6,6 +6,8 @@ import (
"fmt" "fmt"
"net" "net"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
@ -55,7 +57,7 @@ func (s *ServerTLS) ServePacket(p net.PacketConn) error { return nil }
// Listen implements caddy.TCPServer interface. // Listen implements caddy.TCPServer interface.
func (s *ServerTLS) Listen() (net.Listener, error) { func (s *ServerTLS) Listen() (net.Listener, error) {
l, err := net.Listen("tcp", s.Addr[len(TransportTLS+"://"):]) l, err := net.Listen("tcp", s.Addr[len(transport.TLS+"://"):])
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -72,7 +74,7 @@ func (s *ServerTLS) OnStartupComplete() {
return return
} }
out := startUpZones(TransportTLS+"://", s.Addr, s.zones) out := startUpZones(transport.TLS+"://", s.Addr, s.zones)
if out != "" { if out != "" {
fmt.Print(out) fmt.Print(out)
} }

View file

@ -35,16 +35,16 @@ func averageTimeout(currentAvg *int64, observedDuration time.Duration, weight in
atomic.AddInt64(currentAvg, int64(observedDuration-dt)/weight) atomic.AddInt64(currentAvg, int64(observedDuration-dt)/weight)
} }
func (t *transport) dialTimeout() time.Duration { func (t *Transport) dialTimeout() time.Duration {
return limitTimeout(&t.avgDialTime, minDialTimeout, maxDialTimeout) return limitTimeout(&t.avgDialTime, minDialTimeout, maxDialTimeout)
} }
func (t *transport) updateDialTimeout(newDialTime time.Duration) { func (t *Transport) updateDialTimeout(newDialTime time.Duration) {
averageTimeout(&t.avgDialTime, newDialTime, cumulativeAvgWeight) averageTimeout(&t.avgDialTime, newDialTime, cumulativeAvgWeight)
} }
// Dial dials the address configured in transport, potentially reusing a connection or creating a new one. // Dial dials the address configured in transport, potentially reusing a connection or creating a new one.
func (t *transport) Dial(proto string) (*dns.Conn, bool, error) { func (t *Transport) Dial(proto string) (*dns.Conn, bool, error) {
// If tls has been configured; use it. // If tls has been configured; use it.
if t.tlsConfig != nil { if t.tlsConfig != nil {
proto = "tcp-tls" proto = "tcp-tls"

View file

@ -4,6 +4,7 @@ import (
"testing" "testing"
"github.com/coredns/coredns/plugin/pkg/dnstest" "github.com/coredns/coredns/plugin/pkg/dnstest"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/coredns/coredns/plugin/test" "github.com/coredns/coredns/plugin/test"
"github.com/coredns/coredns/request" "github.com/coredns/coredns/request"
@ -19,7 +20,7 @@ func TestForward(t *testing.T) {
}) })
defer s.Close() defer s.Close()
p := NewProxy(s.Addr, DNS) p := NewProxy(s.Addr, transport.DNS)
f := New() f := New()
f.SetProxy(p) f.SetProxy(p)
defer f.Close() defer f.Close()
@ -51,7 +52,7 @@ func TestForwardRefused(t *testing.T) {
}) })
defer s.Close() defer s.Close()
p := NewProxy(s.Addr, DNS) p := NewProxy(s.Addr, transport.DNS)
f := New() f := New()
f.SetProxy(p) f.SetProxy(p)
defer f.Close() defer f.Close()

View file

@ -5,6 +5,8 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
@ -17,10 +19,10 @@ type HealthChecker interface {
// dnsHc is a health checker for a DNS endpoint (DNS, and DoT). // dnsHc is a health checker for a DNS endpoint (DNS, and DoT).
type dnsHc struct{ c *dns.Client } type dnsHc struct{ c *dns.Client }
// NewHealthChecker returns a new HealthChecker based on protocol. // NewHealthChecker returns a new HealthChecker based on transport.
func NewHealthChecker(protocol int) HealthChecker { func NewHealthChecker(trans string) HealthChecker {
switch protocol { switch trans {
case DNS, TLS: case transport.DNS, transport.TLS:
c := new(dns.Client) c := new(dns.Client)
c.Net = "udp" c.Net = "udp"
c.ReadTimeout = 1 * time.Second c.ReadTimeout = 1 * time.Second

View file

@ -7,6 +7,7 @@ import (
"time" "time"
"github.com/coredns/coredns/plugin/pkg/dnstest" "github.com/coredns/coredns/plugin/pkg/dnstest"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/coredns/coredns/plugin/test" "github.com/coredns/coredns/plugin/test"
"github.com/miekg/dns" "github.com/miekg/dns"
@ -25,7 +26,7 @@ func TestHealth(t *testing.T) {
}) })
defer s.Close() defer s.Close()
p := NewProxy(s.Addr, DNS) p := NewProxy(s.Addr, transport.DNS)
f := New() f := New()
f.SetProxy(p) f.SetProxy(p)
defer f.Close() defer f.Close()
@ -65,7 +66,7 @@ func TestHealthTimeout(t *testing.T) {
}) })
defer s.Close() defer s.Close()
p := NewProxy(s.Addr, DNS) p := NewProxy(s.Addr, transport.DNS)
f := New() f := New()
f.SetProxy(p) f.SetProxy(p)
defer f.Close() defer f.Close()
@ -109,7 +110,7 @@ func TestHealthFailTwice(t *testing.T) {
}) })
defer s.Close() defer s.Close()
p := NewProxy(s.Addr, DNS) p := NewProxy(s.Addr, transport.DNS)
f := New() f := New()
f.SetProxy(p) f.SetProxy(p)
defer f.Close() defer f.Close()
@ -132,7 +133,7 @@ func TestHealthMaxFails(t *testing.T) {
}) })
defer s.Close() defer s.Close()
p := NewProxy(s.Addr, DNS) p := NewProxy(s.Addr, transport.DNS)
f := New() f := New()
f.maxfails = 2 f.maxfails = 2
f.SetProxy(p) f.SetProxy(p)
@ -163,7 +164,7 @@ func TestHealthNoMaxFails(t *testing.T) {
}) })
defer s.Close() defer s.Close()
p := NewProxy(s.Addr, DNS) p := NewProxy(s.Addr, transport.DNS)
f := New() f := New()
f.maxfails = 0 f.maxfails = 0
f.SetProxy(p) f.SetProxy(p)

View file

@ -7,6 +7,7 @@ package forward
import ( import (
"context" "context"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/coredns/coredns/request" "github.com/coredns/coredns/request"
"github.com/miekg/dns" "github.com/miekg/dns"
@ -81,7 +82,7 @@ func (f *Forward) Lookup(state request.Request, name string, typ uint16) (*dns.M
func NewLookup(addr []string) *Forward { func NewLookup(addr []string) *Forward {
f := New() f := New()
for i := range addr { for i := range addr {
p := NewProxy(addr[i], DNS) p := NewProxy(addr[i], transport.DNS)
f.SetProxy(p) f.SetProxy(p)
} }
return f return f

View file

@ -4,6 +4,7 @@ import (
"testing" "testing"
"github.com/coredns/coredns/plugin/pkg/dnstest" "github.com/coredns/coredns/plugin/pkg/dnstest"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/coredns/coredns/plugin/test" "github.com/coredns/coredns/plugin/test"
"github.com/coredns/coredns/request" "github.com/coredns/coredns/request"
@ -19,7 +20,7 @@ func TestLookup(t *testing.T) {
}) })
defer s.Close() defer s.Close()
p := NewProxy(s.Addr, DNS) p := NewProxy(s.Addr, transport.DNS)
f := New() f := New()
f.SetProxy(p) f.SetProxy(p)
defer f.Close() defer f.Close()

View file

@ -15,10 +15,10 @@ type persistConn struct {
used time.Time used time.Time
} }
// transport hold the persistent cache. // Transport hold the persistent cache.
type transport struct { type Transport struct {
avgDialTime int64 // kind of average time of dial time avgDialTime int64 // kind of average time of dial time
conns map[string][]*persistConn // Buckets for udp, tcp and tcp-tls. conns map[string][]*persistConn // Buckets for udp, tcp and tcp-tls.
expire time.Duration // After this duration a connection is expired. expire time.Duration // After this duration a connection is expired.
addr string addr string
tlsConfig *tls.Config tlsConfig *tls.Config
@ -29,8 +29,8 @@ type transport struct {
stop chan bool stop chan bool
} }
func newTransport(addr string) *transport { func newTransport(addr string) *Transport {
t := &transport{ t := &Transport{
avgDialTime: int64(defaultDialTimeout / 2), avgDialTime: int64(defaultDialTimeout / 2),
conns: make(map[string][]*persistConn), conns: make(map[string][]*persistConn),
expire: defaultExpire, expire: defaultExpire,
@ -45,7 +45,7 @@ func newTransport(addr string) *transport {
// len returns the number of connection, used for metrics. Can only be safely // len returns the number of connection, used for metrics. Can only be safely
// used inside connManager() because of data races. // used inside connManager() because of data races.
func (t *transport) len() int { func (t *Transport) len() int {
l := 0 l := 0
for _, conns := range t.conns { for _, conns := range t.conns {
l += len(conns) l += len(conns)
@ -54,7 +54,7 @@ func (t *transport) len() int {
} }
// connManagers manages the persistent connection cache for UDP and TCP. // connManagers manages the persistent connection cache for UDP and TCP.
func (t *transport) connManager() { func (t *Transport) connManager() {
ticker := time.NewTicker(t.expire) ticker := time.NewTicker(t.expire)
Wait: Wait:
for { for {
@ -115,7 +115,7 @@ func closeConns(conns []*persistConn) {
} }
// cleanup removes connections from cache. // cleanup removes connections from cache.
func (t *transport) cleanup(all bool) { func (t *Transport) cleanup(all bool) {
staleTime := time.Now().Add(-t.expire) staleTime := time.Now().Add(-t.expire)
for proto, stack := range t.conns { for proto, stack := range t.conns {
if len(stack) == 0 { if len(stack) == 0 {
@ -144,19 +144,19 @@ func (t *transport) cleanup(all bool) {
} }
// Yield return the connection to transport for reuse. // Yield return the connection to transport for reuse.
func (t *transport) Yield(c *dns.Conn) { t.yield <- c } func (t *Transport) Yield(c *dns.Conn) { t.yield <- c }
// Start starts the transport's connection manager. // Start starts the transport's connection manager.
func (t *transport) Start() { go t.connManager() } func (t *Transport) Start() { go t.connManager() }
// Stop stops the transport's connection manager. // Stop stops the transport's connection manager.
func (t *transport) Stop() { close(t.stop) } func (t *Transport) Stop() { close(t.stop) }
// SetExpire sets the connection expire time in transport. // SetExpire sets the connection expire time in transport.
func (t *transport) SetExpire(expire time.Duration) { t.expire = expire } func (t *Transport) SetExpire(expire time.Duration) { t.expire = expire }
// SetTLSConfig sets the TLS config in transport. // SetTLSConfig sets the TLS config in transport.
func (t *transport) SetTLSConfig(cfg *tls.Config) { t.tlsConfig = cfg } func (t *Transport) SetTLSConfig(cfg *tls.Config) { t.tlsConfig = cfg }
const ( const (
defaultExpire = 10 * time.Second defaultExpire = 10 * time.Second

View file

@ -1,30 +0,0 @@
package forward
// Copied from coredns/core/dnsserver/address.go
import (
"strings"
)
// protocol returns the protocol of the string s. The second string returns s
// with the prefix chopped off.
func protocol(s string) (int, string) {
switch {
case strings.HasPrefix(s, _tls+"://"):
return TLS, s[len(_tls)+3:]
case strings.HasPrefix(s, _dns+"://"):
return DNS, s[len(_dns)+3:]
}
return DNS, s
}
// Supported protocols.
const (
DNS = iota + 1
TLS
)
const (
_dns = "dns"
_tls = "tls"
)

View file

@ -18,7 +18,7 @@ type Proxy struct {
// Connection caching // Connection caching
expire time.Duration expire time.Duration
transport *transport transport *Transport
// health checking // health checking
probe *up.Probe probe *up.Probe
@ -26,7 +26,7 @@ type Proxy struct {
} }
// NewProxy returns a new proxy. // NewProxy returns a new proxy.
func NewProxy(addr string, protocol int) *Proxy { func NewProxy(addr, trans string) *Proxy {
p := &Proxy{ p := &Proxy{
addr: addr, addr: addr,
fails: 0, fails: 0,
@ -34,7 +34,7 @@ func NewProxy(addr string, protocol int) *Proxy {
transport: newTransport(addr), transport: newTransport(addr),
avgRtt: int64(maxTimeout / 2), avgRtt: int64(maxTimeout / 2),
} }
p.health = NewHealthChecker(protocol) p.health = NewHealthChecker(trans)
runtime.SetFinalizer(p, (*Proxy).finalizer) runtime.SetFinalizer(p, (*Proxy).finalizer)
return p return p
} }

View file

@ -5,6 +5,7 @@ import (
"testing" "testing"
"github.com/coredns/coredns/plugin/pkg/dnstest" "github.com/coredns/coredns/plugin/pkg/dnstest"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/coredns/coredns/plugin/test" "github.com/coredns/coredns/plugin/test"
"github.com/coredns/coredns/request" "github.com/coredns/coredns/request"
@ -26,7 +27,7 @@ func TestProxyClose(t *testing.T) {
ctx := context.TODO() ctx := context.TODO()
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
p := NewProxy(s.Addr, DNS) p := NewProxy(s.Addr, transport.DNS)
p.start(hcInterval) p.start(hcInterval)
go func() { p.Connect(ctx, state, options{}) }() go func() { p.Connect(ctx, state, options{}) }()
@ -95,7 +96,7 @@ func TestProxyTLSFail(t *testing.T) {
} }
func TestProtocolSelection(t *testing.T) { func TestProtocolSelection(t *testing.T) {
p := NewProxy("bad_address", DNS) p := NewProxy("bad_address", transport.DNS)
stateUDP := request.Request{W: &test.ResponseWriter{}, Req: new(dns.Msg)} stateUDP := request.Request{W: &test.ResponseWriter{}, Req: new(dns.Msg)}
stateTCP := request.Request{W: &test.ResponseWriter{TCP: true}, Req: new(dns.Msg)} stateTCP := request.Request{W: &test.ResponseWriter{TCP: true}, Req: new(dns.Msg)}

View file

@ -2,7 +2,6 @@ package forward
import ( import (
"fmt" "fmt"
"net"
"strconv" "strconv"
"time" "time"
@ -11,6 +10,7 @@ import (
"github.com/coredns/coredns/plugin/metrics" "github.com/coredns/coredns/plugin/metrics"
"github.com/coredns/coredns/plugin/pkg/dnsutil" "github.com/coredns/coredns/plugin/pkg/dnsutil"
pkgtls "github.com/coredns/coredns/plugin/pkg/tls" pkgtls "github.com/coredns/coredns/plugin/pkg/tls"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/mholt/caddy" "github.com/mholt/caddy"
"github.com/mholt/caddy/caddyfile" "github.com/mholt/caddy/caddyfile"
@ -93,8 +93,6 @@ func parseForward(c *caddy.Controller) (*Forward, error) {
func ParseForwardStanza(c *caddyfile.Dispenser) (*Forward, error) { func ParseForwardStanza(c *caddyfile.Dispenser) (*Forward, error) {
f := New() f := New()
protocols := map[int]int{}
if !c.Args(&f.from) { if !c.Args(&f.from) {
return f, c.ArgErr() return f, c.ArgErr()
} }
@ -105,41 +103,17 @@ func ParseForwardStanza(c *caddyfile.Dispenser) (*Forward, error) {
return f, c.ArgErr() return f, c.ArgErr()
} }
// A bit fiddly, but first check if we've got protocols and if so add them back in when we create the proxies.
protocols = make(map[int]int)
for i := range to {
protocols[i], to[i] = protocol(to[i])
}
// If parseHostPortOrFile expands a file with a lot of nameserver our accounting in protocols doesn't make
// any sense anymore... For now: lets don't care.
toHosts, err := dnsutil.ParseHostPortOrFile(to...) toHosts, err := dnsutil.ParseHostPortOrFile(to...)
if err != nil { if err != nil {
return f, err return f, err
} }
for i, h := range toHosts { transports := make([]string, len(toHosts))
// Double check the port, if e.g. is 53 and the transport is TLS make it 853. for i, host := range toHosts {
// This can be somewhat annoying because you *can't* have TLS on port 53 then. trans, h := transport.Parse(host)
switch protocols[i] { p := NewProxy(h, trans)
case TLS:
h1, p, err := net.SplitHostPort(h)
if err != nil {
break
}
// This is more of a bug in dnsutil.ParseHostPortOrFile that defaults to
// 53 because it doesn't know about the tls:// // and friends (that should be fixed). Hence
// Fix the port number here, back to what the user intended.
if p == "53" {
h = net.JoinHostPort(h1, "853")
}
}
// We can't set tlsConfig here, because we haven't parsed it yet.
// We set it below at the end of parseBlock, use nil now.
p := NewProxy(h, protocols[i])
f.proxies = append(f.proxies, p) f.proxies = append(f.proxies, p)
transports[i] = trans
} }
for c.NextBlock() { for c.NextBlock() {
@ -153,7 +127,7 @@ func ParseForwardStanza(c *caddyfile.Dispenser) (*Forward, error) {
} }
for i := range f.proxies { for i := range f.proxies {
// Only set this for proxies that need it. // Only set this for proxies that need it.
if protocols[i] == TLS { if transports[i] == transport.TLS {
f.proxies[i].SetTLSConfig(f.tlsConfig) f.proxies[i].SetTLSConfig(f.tlsConfig)
} }
f.proxies[i].SetExpire(f.expire) f.proxies[i].SetExpire(f.expire)

View file

@ -5,6 +5,7 @@ import (
"testing" "testing"
"github.com/coredns/coredns/plugin/pkg/dnstest" "github.com/coredns/coredns/plugin/pkg/dnstest"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/coredns/coredns/plugin/test" "github.com/coredns/coredns/plugin/test"
"github.com/coredns/coredns/request" "github.com/coredns/coredns/request"
@ -34,7 +35,7 @@ func TestLookupTruncated(t *testing.T) {
}) })
defer s.Close() defer s.Close()
p := NewProxy(s.Addr, DNS) p := NewProxy(s.Addr, transport.DNS)
f := New() f := New()
f.SetProxy(p) f.SetProxy(p)
defer f.Close() defer f.Close()
@ -88,9 +89,9 @@ func TestForwardTruncated(t *testing.T) {
f := New() f := New()
p1 := NewProxy(s.Addr, DNS) p1 := NewProxy(s.Addr, transport.DNS)
f.SetProxy(p1) f.SetProxy(p1)
p2 := NewProxy(s.Addr, DNS) p2 := NewProxy(s.Addr, transport.DNS)
f.SetProxy(p2) f.SetProxy(p2)
defer f.Close() defer f.Close()

View file

@ -6,6 +6,8 @@ import (
"strconv" "strconv"
"strings" "strings"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
@ -61,22 +63,10 @@ type (
// Normalize will return the host portion of host, stripping // Normalize will return the host portion of host, stripping
// of any port or transport. The host will also be fully qualified and lowercased. // of any port or transport. The host will also be fully qualified and lowercased.
func (h Host) Normalize() string { func (h Host) Normalize() string {
s := string(h) s := string(h)
_, s = transport.Parse(s)
switch { // The error can be ignore here, because this function is called after the corefile has already been vetted.
case strings.HasPrefix(s, TransportTLS+"://"):
s = s[len(TransportTLS+"://"):]
case strings.HasPrefix(s, TransportDNS+"://"):
s = s[len(TransportDNS+"://"):]
case strings.HasPrefix(s, TransportGRPC+"://"):
s = s[len(TransportGRPC+"://"):]
case strings.HasPrefix(s, TransportHTTPS+"://"):
s = s[len(TransportHTTPS+"://"):]
}
// The error can be ignore here, because this function is called after the corefile
// has already been vetted.
host, _, _, _ := SplitHostPort(s) host, _, _, _ := SplitHostPort(s)
return Name(host).Normalize() return Name(host).Normalize()
} }
@ -138,11 +128,3 @@ func SplitHostPort(s string) (host, port string, ipnet *net.IPNet, err error) {
} }
return host, port, n, nil return host, port, n, nil
} }
// Duplicated from core/dnsserver/address.go !
const (
TransportDNS = "dns"
TransportTLS = "tls"
TransportGRPC = "grpc"
TransportHTTPS = "https"
)

View file

@ -5,15 +5,21 @@ import (
"net" "net"
"os" "os"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
// ParseHostPortOrFile parses the strings in s, each string can either be a address, // ParseHostPortOrFile parses the strings in s, each string can either be a
// address:port or a filename. The address part is checked and the filename case a // address, [scheme://]address:port or a filename. The address part is checked
// resolv.conf like file is parsed and the nameserver found are returned. // and in case of filename a resolv.conf like file is (assumed) and parsed and
// the nameservers found are returned.
func ParseHostPortOrFile(s ...string) ([]string, error) { func ParseHostPortOrFile(s ...string) ([]string, error) {
var servers []string var servers []string
for _, host := range s { for _, h := range s {
trans, host := transport.Parse(h)
addr, _, err := net.SplitHostPort(host) addr, _, err := net.SplitHostPort(host)
if err != nil { if err != nil {
// Parse didn't work, it is not a addr:port combo // Parse didn't work, it is not a addr:port combo
@ -26,13 +32,23 @@ func ParseHostPortOrFile(s ...string) ([]string, error) {
} }
return servers, fmt.Errorf("not an IP address or file: %q", host) return servers, fmt.Errorf("not an IP address or file: %q", host)
} }
ss := net.JoinHostPort(host, "53") var ss string
switch trans {
case transport.DNS:
ss = net.JoinHostPort(host, "53")
case transport.TLS:
ss = transport.TLS + "://" + net.JoinHostPort(host, transport.TLSPort)
case transport.GRPC:
ss = transport.GRPC + "://" + net.JoinHostPort(host, transport.GRPCPort)
case transport.HTTPS:
ss = transport.HTTPS + "://" + net.JoinHostPort(host, transport.HTTPSPort)
}
servers = append(servers, ss) servers = append(servers, ss)
continue continue
} }
if net.ParseIP(addr) == nil { if net.ParseIP(addr) == nil {
// No an IP address. // Not an IP address.
ss, err := tryFile(host) ss, err := tryFile(host)
if err == nil { if err == nil {
servers = append(servers, ss...) servers = append(servers, ss...)
@ -40,7 +56,7 @@ func ParseHostPortOrFile(s ...string) ([]string, error) {
} }
return servers, fmt.Errorf("not an IP address or file: %q", host) return servers, fmt.Errorf("not an IP address or file: %q", host)
} }
servers = append(servers, host) servers = append(servers, h)
} }
return servers, nil return servers, nil
} }

View file

@ -0,0 +1,49 @@
package transport
import (
"strings"
)
// Parse returns the transport defined in s and a string where the
// transport prefix is removed (if there was any). If no transport is defined
// we default to TransportDNS
func Parse(s string) (transport string, addr string) {
switch {
case strings.HasPrefix(s, TLS+"://"):
s = s[len(TLS+"://"):]
return TLS, s
case strings.HasPrefix(s, DNS+"://"):
s = s[len(DNS+"://"):]
return DNS, s
case strings.HasPrefix(s, GRPC+"://"):
s = s[len(GRPC+"://"):]
return GRPC, s
case strings.HasPrefix(s, HTTPS+"://"):
s = s[len(HTTPS+"://"):]
return HTTPS, s
}
return DNS, s
}
// Supported transports.
const (
DNS = "dns"
TLS = "tls"
GRPC = "grpc"
HTTPS = "https"
)
// Port numbers for the various protocols
const (
// TLSPort is the default port for DNS-over-TLS.
TLSPort = "853"
// GRPCPort is the default port for DNS-over-gRPC.
GRPCPort = "443"
// HTTPSPort is the default port for DNS-over-HTTPS.
HTTPSPort = "443"
)

View file

@ -0,0 +1,21 @@
package transport
import "testing"
func TestParse(t *testing.T) {
for i, test := range []struct {
input string
expected string
}{
{"dns://.:53", DNS},
{"2003::1/64.:53", DNS},
{"grpc://example.org:1443 ", GRPC},
{"tls://example.org ", TLS},
{"https://example.org ", HTTPS},
} {
actual, _ := Parse(test.input)
if actual != test.expected {
t.Errorf("Test %d: Expected %s but got %s", i, test.expected, actual)
}
}
}