Cleanup ParseHostOrFile (#2100)
Create plugin/pkg/transport that holds the transport related functions. This needed to be a new pkg to prevent cyclic import errors. This cleans up a bunch of duplicated code in core/dnsserver that also tried to parse a transport (now all done in transport.Parse). Signed-off-by: Miek Gieben <miek@miek.nl>
This commit is contained in:
parent
2f1223c36a
commit
c349446a23
24 changed files with 182 additions and 221 deletions
|
@ -6,6 +6,7 @@ import (
|
|||
"strings"
|
||||
|
||||
"github.com/coredns/coredns/plugin"
|
||||
"github.com/coredns/coredns/plugin/pkg/transport"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
@ -27,43 +28,13 @@ func (z zoneAddr) String() string {
|
|||
return s
|
||||
}
|
||||
|
||||
// Transport returns the protocol of the string s
|
||||
func Transport(s string) string {
|
||||
switch {
|
||||
case strings.HasPrefix(s, TransportTLS+"://"):
|
||||
return TransportTLS
|
||||
case strings.HasPrefix(s, TransportDNS+"://"):
|
||||
return TransportDNS
|
||||
case strings.HasPrefix(s, TransportGRPC+"://"):
|
||||
return TransportGRPC
|
||||
case strings.HasPrefix(s, TransportHTTPS+"://"):
|
||||
return TransportHTTPS
|
||||
}
|
||||
return TransportDNS
|
||||
}
|
||||
|
||||
// normalizeZone parses an zone string into a structured format with separate
|
||||
// host, and port portions, as well as the original input string.
|
||||
func normalizeZone(str string) (zoneAddr, error) {
|
||||
var err error
|
||||
|
||||
// Default to DNS if there isn't a transport protocol prefix.
|
||||
trans := TransportDNS
|
||||
|
||||
switch {
|
||||
case strings.HasPrefix(str, TransportTLS+"://"):
|
||||
trans = TransportTLS
|
||||
str = str[len(TransportTLS+"://"):]
|
||||
case strings.HasPrefix(str, TransportDNS+"://"):
|
||||
trans = TransportDNS
|
||||
str = str[len(TransportDNS+"://"):]
|
||||
case strings.HasPrefix(str, TransportGRPC+"://"):
|
||||
trans = TransportGRPC
|
||||
str = str[len(TransportGRPC+"://"):]
|
||||
case strings.HasPrefix(str, TransportHTTPS+"://"):
|
||||
trans = TransportHTTPS
|
||||
str = str[len(TransportHTTPS+"://"):]
|
||||
}
|
||||
var trans string
|
||||
trans, str = transport.Parse(str)
|
||||
|
||||
host, port, ipnet, err := plugin.SplitHostPort(str)
|
||||
if err != nil {
|
||||
|
@ -71,17 +42,15 @@ func normalizeZone(str string) (zoneAddr, error) {
|
|||
}
|
||||
|
||||
if port == "" {
|
||||
if trans == TransportDNS {
|
||||
switch trans {
|
||||
case transport.DNS:
|
||||
port = Port
|
||||
}
|
||||
if trans == TransportTLS {
|
||||
port = TLSPort
|
||||
}
|
||||
if trans == TransportGRPC {
|
||||
port = GRPCPort
|
||||
}
|
||||
if trans == TransportHTTPS {
|
||||
port = HTTPSPort
|
||||
case transport.TLS:
|
||||
port = transport.TLSPort
|
||||
case transport.GRPC:
|
||||
port = transport.GRPCPort
|
||||
case transport.HTTPS:
|
||||
port = transport.HTTPSPort
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -103,14 +72,6 @@ func SplitProtocolHostPort(address string) (protocol string, ip string, port str
|
|||
}
|
||||
}
|
||||
|
||||
// Supported transports.
|
||||
const (
|
||||
TransportDNS = "dns"
|
||||
TransportTLS = "tls"
|
||||
TransportGRPC = "grpc"
|
||||
TransportHTTPS = "https"
|
||||
)
|
||||
|
||||
type zoneOverlap struct {
|
||||
registeredAddr map[zoneAddr]zoneAddr // each zoneAddr is registered once by its key
|
||||
unboundOverlap map[zoneAddr]zoneAddr // the "no bind" equiv ZoneAdddr is registered by its original key
|
||||
|
|
|
@ -192,21 +192,3 @@ func TestOverlapAddressChecker(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransport(t *testing.T) {
|
||||
for i, test := range []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"dns://.:53", TransportDNS},
|
||||
{"2003::1/64.:53", TransportDNS},
|
||||
{"grpc://example.org:1443 ", TransportGRPC},
|
||||
{"tls://example.org ", TransportTLS},
|
||||
{"https://example.org ", TransportHTTPS},
|
||||
} {
|
||||
actual := Transport(test.input)
|
||||
if actual != test.expected {
|
||||
t.Errorf("Test %d: Expected %s but got %s", i, test.expected, actual)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
|
||||
"github.com/coredns/coredns/plugin"
|
||||
"github.com/coredns/coredns/plugin/pkg/dnsutil"
|
||||
"github.com/coredns/coredns/plugin/pkg/transport"
|
||||
|
||||
"github.com/mholt/caddy"
|
||||
"github.com/mholt/caddy/caddyfile"
|
||||
|
@ -111,29 +112,29 @@ func (h *dnsContext) MakeServers() ([]caddy.Server, error) {
|
|||
var servers []caddy.Server
|
||||
for addr, group := range groups {
|
||||
// switch on addr
|
||||
switch Transport(addr) {
|
||||
case TransportDNS:
|
||||
switch tr, _ := transport.Parse(addr); tr {
|
||||
case transport.DNS:
|
||||
s, err := NewServer(addr, group)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
servers = append(servers, s)
|
||||
|
||||
case TransportTLS:
|
||||
case transport.TLS:
|
||||
s, err := NewServerTLS(addr, group)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
servers = append(servers, s)
|
||||
|
||||
case TransportGRPC:
|
||||
case transport.GRPC:
|
||||
s, err := NewServergRPC(addr, group)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
servers = append(servers, s)
|
||||
|
||||
case TransportHTTPS:
|
||||
case transport.HTTPS:
|
||||
s, err := NewServerHTTPS(addr, group)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -234,16 +235,8 @@ func groupConfigsByListenAddr(configs []*Config) (map[string][]*Config, error) {
|
|||
return groups, nil
|
||||
}
|
||||
|
||||
const (
|
||||
// DefaultPort is the default port.
|
||||
DefaultPort = "53"
|
||||
// TLSPort is the default port for DNS-over-TLS.
|
||||
TLSPort = "853"
|
||||
// GRPCPort is the default port for DNS-over-gRPC.
|
||||
GRPCPort = "443"
|
||||
// HTTPSPort is the default port for DNS-over-HTTPS.
|
||||
HTTPSPort = "443"
|
||||
)
|
||||
const DefaultPort = "53"
|
||||
|
||||
// These "soft defaults" are configurable by
|
||||
// command line flags, etc.
|
||||
|
|
|
@ -15,6 +15,7 @@ import (
|
|||
"github.com/coredns/coredns/plugin/pkg/log"
|
||||
"github.com/coredns/coredns/plugin/pkg/rcode"
|
||||
"github.com/coredns/coredns/plugin/pkg/trace"
|
||||
"github.com/coredns/coredns/plugin/pkg/transport"
|
||||
"github.com/coredns/coredns/request"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
@ -134,7 +135,7 @@ func (s *Server) ServePacket(p net.PacketConn) error {
|
|||
|
||||
// Listen implements caddy.TCPServer interface.
|
||||
func (s *Server) Listen() (net.Listener, error) {
|
||||
l, err := net.Listen("tcp", s.Addr[len(TransportDNS+"://"):])
|
||||
l, err := net.Listen("tcp", s.Addr[len(transport.DNS+"://"):])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -143,7 +144,7 @@ func (s *Server) Listen() (net.Listener, error) {
|
|||
|
||||
// ListenPacket implements caddy.UDPServer interface.
|
||||
func (s *Server) ListenPacket() (net.PacketConn, error) {
|
||||
p, err := net.ListenPacket("udp", s.Addr[len(TransportDNS+"://"):])
|
||||
p, err := net.ListenPacket("udp", s.Addr[len(transport.DNS+"://"):])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"net"
|
||||
|
||||
"github.com/coredns/coredns/pb"
|
||||
"github.com/coredns/coredns/plugin/pkg/transport"
|
||||
"github.com/coredns/coredns/plugin/pkg/watch"
|
||||
|
||||
"github.com/grpc-ecosystem/grpc-opentracing/go/otgrpc"
|
||||
|
@ -73,7 +74,7 @@ func (s *ServergRPC) ServePacket(p net.PacketConn) error { return nil }
|
|||
// Listen implements caddy.TCPServer interface.
|
||||
func (s *ServergRPC) Listen() (net.Listener, error) {
|
||||
|
||||
l, err := net.Listen("tcp", s.Addr[len(TransportGRPC+"://"):])
|
||||
l, err := net.Listen("tcp", s.Addr[len(transport.GRPC+"://"):])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -90,7 +91,7 @@ func (s *ServergRPC) OnStartupComplete() {
|
|||
return
|
||||
}
|
||||
|
||||
out := startUpZones(TransportGRPC+"://", s.Addr, s.zones)
|
||||
out := startUpZones(transport.GRPC+"://", s.Addr, s.zones)
|
||||
if out != "" {
|
||||
fmt.Print(out)
|
||||
}
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
"github.com/coredns/coredns/plugin/pkg/dnsutil"
|
||||
"github.com/coredns/coredns/plugin/pkg/doh"
|
||||
"github.com/coredns/coredns/plugin/pkg/response"
|
||||
"github.com/coredns/coredns/plugin/pkg/transport"
|
||||
)
|
||||
|
||||
// ServerHTTPS represents an instance of a DNS-over-HTTPS server.
|
||||
|
@ -60,7 +61,7 @@ func (s *ServerHTTPS) ServePacket(p net.PacketConn) error { return nil }
|
|||
// Listen implements caddy.TCPServer interface.
|
||||
func (s *ServerHTTPS) Listen() (net.Listener, error) {
|
||||
|
||||
l, err := net.Listen("tcp", s.Addr[len(TransportHTTPS+"://"):])
|
||||
l, err := net.Listen("tcp", s.Addr[len(transport.HTTPS+"://"):])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -77,7 +78,7 @@ func (s *ServerHTTPS) OnStartupComplete() {
|
|||
return
|
||||
}
|
||||
|
||||
out := startUpZones(TransportHTTPS+"://", s.Addr, s.zones)
|
||||
out := startUpZones(transport.HTTPS+"://", s.Addr, s.zones)
|
||||
if out != "" {
|
||||
fmt.Print(out)
|
||||
}
|
||||
|
|
|
@ -6,6 +6,8 @@ import (
|
|||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/coredns/coredns/plugin/pkg/transport"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
|
@ -55,7 +57,7 @@ func (s *ServerTLS) ServePacket(p net.PacketConn) error { return nil }
|
|||
|
||||
// Listen implements caddy.TCPServer interface.
|
||||
func (s *ServerTLS) Listen() (net.Listener, error) {
|
||||
l, err := net.Listen("tcp", s.Addr[len(TransportTLS+"://"):])
|
||||
l, err := net.Listen("tcp", s.Addr[len(transport.TLS+"://"):])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -72,7 +74,7 @@ func (s *ServerTLS) OnStartupComplete() {
|
|||
return
|
||||
}
|
||||
|
||||
out := startUpZones(TransportTLS+"://", s.Addr, s.zones)
|
||||
out := startUpZones(transport.TLS+"://", s.Addr, s.zones)
|
||||
if out != "" {
|
||||
fmt.Print(out)
|
||||
}
|
||||
|
|
|
@ -35,16 +35,16 @@ func averageTimeout(currentAvg *int64, observedDuration time.Duration, weight in
|
|||
atomic.AddInt64(currentAvg, int64(observedDuration-dt)/weight)
|
||||
}
|
||||
|
||||
func (t *transport) dialTimeout() time.Duration {
|
||||
func (t *Transport) dialTimeout() time.Duration {
|
||||
return limitTimeout(&t.avgDialTime, minDialTimeout, maxDialTimeout)
|
||||
}
|
||||
|
||||
func (t *transport) updateDialTimeout(newDialTime time.Duration) {
|
||||
func (t *Transport) updateDialTimeout(newDialTime time.Duration) {
|
||||
averageTimeout(&t.avgDialTime, newDialTime, cumulativeAvgWeight)
|
||||
}
|
||||
|
||||
// Dial dials the address configured in transport, potentially reusing a connection or creating a new one.
|
||||
func (t *transport) Dial(proto string) (*dns.Conn, bool, error) {
|
||||
func (t *Transport) Dial(proto string) (*dns.Conn, bool, error) {
|
||||
// If tls has been configured; use it.
|
||||
if t.tlsConfig != nil {
|
||||
proto = "tcp-tls"
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/coredns/coredns/plugin/pkg/dnstest"
|
||||
"github.com/coredns/coredns/plugin/pkg/transport"
|
||||
"github.com/coredns/coredns/plugin/test"
|
||||
"github.com/coredns/coredns/request"
|
||||
|
||||
|
@ -19,7 +20,7 @@ func TestForward(t *testing.T) {
|
|||
})
|
||||
defer s.Close()
|
||||
|
||||
p := NewProxy(s.Addr, DNS)
|
||||
p := NewProxy(s.Addr, transport.DNS)
|
||||
f := New()
|
||||
f.SetProxy(p)
|
||||
defer f.Close()
|
||||
|
@ -51,7 +52,7 @@ func TestForwardRefused(t *testing.T) {
|
|||
})
|
||||
defer s.Close()
|
||||
|
||||
p := NewProxy(s.Addr, DNS)
|
||||
p := NewProxy(s.Addr, transport.DNS)
|
||||
f := New()
|
||||
f.SetProxy(p)
|
||||
defer f.Close()
|
||||
|
|
|
@ -5,6 +5,8 @@ import (
|
|||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/coredns/coredns/plugin/pkg/transport"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
|
@ -17,10 +19,10 @@ type HealthChecker interface {
|
|||
// dnsHc is a health checker for a DNS endpoint (DNS, and DoT).
|
||||
type dnsHc struct{ c *dns.Client }
|
||||
|
||||
// NewHealthChecker returns a new HealthChecker based on protocol.
|
||||
func NewHealthChecker(protocol int) HealthChecker {
|
||||
switch protocol {
|
||||
case DNS, TLS:
|
||||
// NewHealthChecker returns a new HealthChecker based on transport.
|
||||
func NewHealthChecker(trans string) HealthChecker {
|
||||
switch trans {
|
||||
case transport.DNS, transport.TLS:
|
||||
c := new(dns.Client)
|
||||
c.Net = "udp"
|
||||
c.ReadTimeout = 1 * time.Second
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/coredns/coredns/plugin/pkg/dnstest"
|
||||
"github.com/coredns/coredns/plugin/pkg/transport"
|
||||
"github.com/coredns/coredns/plugin/test"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
@ -25,7 +26,7 @@ func TestHealth(t *testing.T) {
|
|||
})
|
||||
defer s.Close()
|
||||
|
||||
p := NewProxy(s.Addr, DNS)
|
||||
p := NewProxy(s.Addr, transport.DNS)
|
||||
f := New()
|
||||
f.SetProxy(p)
|
||||
defer f.Close()
|
||||
|
@ -65,7 +66,7 @@ func TestHealthTimeout(t *testing.T) {
|
|||
})
|
||||
defer s.Close()
|
||||
|
||||
p := NewProxy(s.Addr, DNS)
|
||||
p := NewProxy(s.Addr, transport.DNS)
|
||||
f := New()
|
||||
f.SetProxy(p)
|
||||
defer f.Close()
|
||||
|
@ -109,7 +110,7 @@ func TestHealthFailTwice(t *testing.T) {
|
|||
})
|
||||
defer s.Close()
|
||||
|
||||
p := NewProxy(s.Addr, DNS)
|
||||
p := NewProxy(s.Addr, transport.DNS)
|
||||
f := New()
|
||||
f.SetProxy(p)
|
||||
defer f.Close()
|
||||
|
@ -132,7 +133,7 @@ func TestHealthMaxFails(t *testing.T) {
|
|||
})
|
||||
defer s.Close()
|
||||
|
||||
p := NewProxy(s.Addr, DNS)
|
||||
p := NewProxy(s.Addr, transport.DNS)
|
||||
f := New()
|
||||
f.maxfails = 2
|
||||
f.SetProxy(p)
|
||||
|
@ -163,7 +164,7 @@ func TestHealthNoMaxFails(t *testing.T) {
|
|||
})
|
||||
defer s.Close()
|
||||
|
||||
p := NewProxy(s.Addr, DNS)
|
||||
p := NewProxy(s.Addr, transport.DNS)
|
||||
f := New()
|
||||
f.maxfails = 0
|
||||
f.SetProxy(p)
|
||||
|
|
|
@ -7,6 +7,7 @@ package forward
|
|||
import (
|
||||
"context"
|
||||
|
||||
"github.com/coredns/coredns/plugin/pkg/transport"
|
||||
"github.com/coredns/coredns/request"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
@ -81,7 +82,7 @@ func (f *Forward) Lookup(state request.Request, name string, typ uint16) (*dns.M
|
|||
func NewLookup(addr []string) *Forward {
|
||||
f := New()
|
||||
for i := range addr {
|
||||
p := NewProxy(addr[i], DNS)
|
||||
p := NewProxy(addr[i], transport.DNS)
|
||||
f.SetProxy(p)
|
||||
}
|
||||
return f
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/coredns/coredns/plugin/pkg/dnstest"
|
||||
"github.com/coredns/coredns/plugin/pkg/transport"
|
||||
"github.com/coredns/coredns/plugin/test"
|
||||
"github.com/coredns/coredns/request"
|
||||
|
||||
|
@ -19,7 +20,7 @@ func TestLookup(t *testing.T) {
|
|||
})
|
||||
defer s.Close()
|
||||
|
||||
p := NewProxy(s.Addr, DNS)
|
||||
p := NewProxy(s.Addr, transport.DNS)
|
||||
f := New()
|
||||
f.SetProxy(p)
|
||||
defer f.Close()
|
||||
|
|
|
@ -15,8 +15,8 @@ type persistConn struct {
|
|||
used time.Time
|
||||
}
|
||||
|
||||
// transport hold the persistent cache.
|
||||
type transport struct {
|
||||
// Transport hold the persistent cache.
|
||||
type Transport struct {
|
||||
avgDialTime int64 // kind of average time of dial time
|
||||
conns map[string][]*persistConn // Buckets for udp, tcp and tcp-tls.
|
||||
expire time.Duration // After this duration a connection is expired.
|
||||
|
@ -29,8 +29,8 @@ type transport struct {
|
|||
stop chan bool
|
||||
}
|
||||
|
||||
func newTransport(addr string) *transport {
|
||||
t := &transport{
|
||||
func newTransport(addr string) *Transport {
|
||||
t := &Transport{
|
||||
avgDialTime: int64(defaultDialTimeout / 2),
|
||||
conns: make(map[string][]*persistConn),
|
||||
expire: defaultExpire,
|
||||
|
@ -45,7 +45,7 @@ func newTransport(addr string) *transport {
|
|||
|
||||
// len returns the number of connection, used for metrics. Can only be safely
|
||||
// used inside connManager() because of data races.
|
||||
func (t *transport) len() int {
|
||||
func (t *Transport) len() int {
|
||||
l := 0
|
||||
for _, conns := range t.conns {
|
||||
l += len(conns)
|
||||
|
@ -54,7 +54,7 @@ func (t *transport) len() int {
|
|||
}
|
||||
|
||||
// connManagers manages the persistent connection cache for UDP and TCP.
|
||||
func (t *transport) connManager() {
|
||||
func (t *Transport) connManager() {
|
||||
ticker := time.NewTicker(t.expire)
|
||||
Wait:
|
||||
for {
|
||||
|
@ -115,7 +115,7 @@ func closeConns(conns []*persistConn) {
|
|||
}
|
||||
|
||||
// cleanup removes connections from cache.
|
||||
func (t *transport) cleanup(all bool) {
|
||||
func (t *Transport) cleanup(all bool) {
|
||||
staleTime := time.Now().Add(-t.expire)
|
||||
for proto, stack := range t.conns {
|
||||
if len(stack) == 0 {
|
||||
|
@ -144,19 +144,19 @@ func (t *transport) cleanup(all bool) {
|
|||
}
|
||||
|
||||
// Yield return the connection to transport for reuse.
|
||||
func (t *transport) Yield(c *dns.Conn) { t.yield <- c }
|
||||
func (t *Transport) Yield(c *dns.Conn) { t.yield <- c }
|
||||
|
||||
// Start starts the transport's connection manager.
|
||||
func (t *transport) Start() { go t.connManager() }
|
||||
func (t *Transport) Start() { go t.connManager() }
|
||||
|
||||
// Stop stops the transport's connection manager.
|
||||
func (t *transport) Stop() { close(t.stop) }
|
||||
func (t *Transport) Stop() { close(t.stop) }
|
||||
|
||||
// SetExpire sets the connection expire time in transport.
|
||||
func (t *transport) SetExpire(expire time.Duration) { t.expire = expire }
|
||||
func (t *Transport) SetExpire(expire time.Duration) { t.expire = expire }
|
||||
|
||||
// SetTLSConfig sets the TLS config in transport.
|
||||
func (t *transport) SetTLSConfig(cfg *tls.Config) { t.tlsConfig = cfg }
|
||||
func (t *Transport) SetTLSConfig(cfg *tls.Config) { t.tlsConfig = cfg }
|
||||
|
||||
const (
|
||||
defaultExpire = 10 * time.Second
|
||||
|
|
|
@ -1,30 +0,0 @@
|
|||
package forward
|
||||
|
||||
// Copied from coredns/core/dnsserver/address.go
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// protocol returns the protocol of the string s. The second string returns s
|
||||
// with the prefix chopped off.
|
||||
func protocol(s string) (int, string) {
|
||||
switch {
|
||||
case strings.HasPrefix(s, _tls+"://"):
|
||||
return TLS, s[len(_tls)+3:]
|
||||
case strings.HasPrefix(s, _dns+"://"):
|
||||
return DNS, s[len(_dns)+3:]
|
||||
}
|
||||
return DNS, s
|
||||
}
|
||||
|
||||
// Supported protocols.
|
||||
const (
|
||||
DNS = iota + 1
|
||||
TLS
|
||||
)
|
||||
|
||||
const (
|
||||
_dns = "dns"
|
||||
_tls = "tls"
|
||||
)
|
|
@ -18,7 +18,7 @@ type Proxy struct {
|
|||
|
||||
// Connection caching
|
||||
expire time.Duration
|
||||
transport *transport
|
||||
transport *Transport
|
||||
|
||||
// health checking
|
||||
probe *up.Probe
|
||||
|
@ -26,7 +26,7 @@ type Proxy struct {
|
|||
}
|
||||
|
||||
// NewProxy returns a new proxy.
|
||||
func NewProxy(addr string, protocol int) *Proxy {
|
||||
func NewProxy(addr, trans string) *Proxy {
|
||||
p := &Proxy{
|
||||
addr: addr,
|
||||
fails: 0,
|
||||
|
@ -34,7 +34,7 @@ func NewProxy(addr string, protocol int) *Proxy {
|
|||
transport: newTransport(addr),
|
||||
avgRtt: int64(maxTimeout / 2),
|
||||
}
|
||||
p.health = NewHealthChecker(protocol)
|
||||
p.health = NewHealthChecker(trans)
|
||||
runtime.SetFinalizer(p, (*Proxy).finalizer)
|
||||
return p
|
||||
}
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/coredns/coredns/plugin/pkg/dnstest"
|
||||
"github.com/coredns/coredns/plugin/pkg/transport"
|
||||
"github.com/coredns/coredns/plugin/test"
|
||||
"github.com/coredns/coredns/request"
|
||||
|
||||
|
@ -26,7 +27,7 @@ func TestProxyClose(t *testing.T) {
|
|||
ctx := context.TODO()
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
p := NewProxy(s.Addr, DNS)
|
||||
p := NewProxy(s.Addr, transport.DNS)
|
||||
p.start(hcInterval)
|
||||
|
||||
go func() { p.Connect(ctx, state, options{}) }()
|
||||
|
@ -95,7 +96,7 @@ func TestProxyTLSFail(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestProtocolSelection(t *testing.T) {
|
||||
p := NewProxy("bad_address", DNS)
|
||||
p := NewProxy("bad_address", transport.DNS)
|
||||
|
||||
stateUDP := request.Request{W: &test.ResponseWriter{}, Req: new(dns.Msg)}
|
||||
stateTCP := request.Request{W: &test.ResponseWriter{TCP: true}, Req: new(dns.Msg)}
|
||||
|
|
|
@ -2,7 +2,6 @@ package forward
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
|
@ -11,6 +10,7 @@ import (
|
|||
"github.com/coredns/coredns/plugin/metrics"
|
||||
"github.com/coredns/coredns/plugin/pkg/dnsutil"
|
||||
pkgtls "github.com/coredns/coredns/plugin/pkg/tls"
|
||||
"github.com/coredns/coredns/plugin/pkg/transport"
|
||||
|
||||
"github.com/mholt/caddy"
|
||||
"github.com/mholt/caddy/caddyfile"
|
||||
|
@ -93,8 +93,6 @@ func parseForward(c *caddy.Controller) (*Forward, error) {
|
|||
func ParseForwardStanza(c *caddyfile.Dispenser) (*Forward, error) {
|
||||
f := New()
|
||||
|
||||
protocols := map[int]int{}
|
||||
|
||||
if !c.Args(&f.from) {
|
||||
return f, c.ArgErr()
|
||||
}
|
||||
|
@ -105,41 +103,17 @@ func ParseForwardStanza(c *caddyfile.Dispenser) (*Forward, error) {
|
|||
return f, c.ArgErr()
|
||||
}
|
||||
|
||||
// A bit fiddly, but first check if we've got protocols and if so add them back in when we create the proxies.
|
||||
protocols = make(map[int]int)
|
||||
for i := range to {
|
||||
protocols[i], to[i] = protocol(to[i])
|
||||
}
|
||||
|
||||
// If parseHostPortOrFile expands a file with a lot of nameserver our accounting in protocols doesn't make
|
||||
// any sense anymore... For now: lets don't care.
|
||||
toHosts, err := dnsutil.ParseHostPortOrFile(to...)
|
||||
if err != nil {
|
||||
return f, err
|
||||
}
|
||||
|
||||
for i, h := range toHosts {
|
||||
// Double check the port, if e.g. is 53 and the transport is TLS make it 853.
|
||||
// This can be somewhat annoying because you *can't* have TLS on port 53 then.
|
||||
switch protocols[i] {
|
||||
case TLS:
|
||||
h1, p, err := net.SplitHostPort(h)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
// This is more of a bug in dnsutil.ParseHostPortOrFile that defaults to
|
||||
// 53 because it doesn't know about the tls:// // and friends (that should be fixed). Hence
|
||||
// Fix the port number here, back to what the user intended.
|
||||
if p == "53" {
|
||||
h = net.JoinHostPort(h1, "853")
|
||||
}
|
||||
}
|
||||
|
||||
// We can't set tlsConfig here, because we haven't parsed it yet.
|
||||
// We set it below at the end of parseBlock, use nil now.
|
||||
p := NewProxy(h, protocols[i])
|
||||
transports := make([]string, len(toHosts))
|
||||
for i, host := range toHosts {
|
||||
trans, h := transport.Parse(host)
|
||||
p := NewProxy(h, trans)
|
||||
f.proxies = append(f.proxies, p)
|
||||
transports[i] = trans
|
||||
}
|
||||
|
||||
for c.NextBlock() {
|
||||
|
@ -153,7 +127,7 @@ func ParseForwardStanza(c *caddyfile.Dispenser) (*Forward, error) {
|
|||
}
|
||||
for i := range f.proxies {
|
||||
// Only set this for proxies that need it.
|
||||
if protocols[i] == TLS {
|
||||
if transports[i] == transport.TLS {
|
||||
f.proxies[i].SetTLSConfig(f.tlsConfig)
|
||||
}
|
||||
f.proxies[i].SetExpire(f.expire)
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/coredns/coredns/plugin/pkg/dnstest"
|
||||
"github.com/coredns/coredns/plugin/pkg/transport"
|
||||
"github.com/coredns/coredns/plugin/test"
|
||||
"github.com/coredns/coredns/request"
|
||||
|
||||
|
@ -34,7 +35,7 @@ func TestLookupTruncated(t *testing.T) {
|
|||
})
|
||||
defer s.Close()
|
||||
|
||||
p := NewProxy(s.Addr, DNS)
|
||||
p := NewProxy(s.Addr, transport.DNS)
|
||||
f := New()
|
||||
f.SetProxy(p)
|
||||
defer f.Close()
|
||||
|
@ -88,9 +89,9 @@ func TestForwardTruncated(t *testing.T) {
|
|||
|
||||
f := New()
|
||||
|
||||
p1 := NewProxy(s.Addr, DNS)
|
||||
p1 := NewProxy(s.Addr, transport.DNS)
|
||||
f.SetProxy(p1)
|
||||
p2 := NewProxy(s.Addr, DNS)
|
||||
p2 := NewProxy(s.Addr, transport.DNS)
|
||||
f.SetProxy(p2)
|
||||
defer f.Close()
|
||||
|
||||
|
|
|
@ -6,6 +6,8 @@ import (
|
|||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/coredns/coredns/plugin/pkg/transport"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
|
@ -61,22 +63,10 @@ type (
|
|||
// Normalize will return the host portion of host, stripping
|
||||
// of any port or transport. The host will also be fully qualified and lowercased.
|
||||
func (h Host) Normalize() string {
|
||||
|
||||
s := string(h)
|
||||
_, s = transport.Parse(s)
|
||||
|
||||
switch {
|
||||
case strings.HasPrefix(s, TransportTLS+"://"):
|
||||
s = s[len(TransportTLS+"://"):]
|
||||
case strings.HasPrefix(s, TransportDNS+"://"):
|
||||
s = s[len(TransportDNS+"://"):]
|
||||
case strings.HasPrefix(s, TransportGRPC+"://"):
|
||||
s = s[len(TransportGRPC+"://"):]
|
||||
case strings.HasPrefix(s, TransportHTTPS+"://"):
|
||||
s = s[len(TransportHTTPS+"://"):]
|
||||
}
|
||||
|
||||
// The error can be ignore here, because this function is called after the corefile
|
||||
// has already been vetted.
|
||||
// The error can be ignore here, because this function is called after the corefile has already been vetted.
|
||||
host, _, _, _ := SplitHostPort(s)
|
||||
return Name(host).Normalize()
|
||||
}
|
||||
|
@ -138,11 +128,3 @@ func SplitHostPort(s string) (host, port string, ipnet *net.IPNet, err error) {
|
|||
}
|
||||
return host, port, n, nil
|
||||
}
|
||||
|
||||
// Duplicated from core/dnsserver/address.go !
|
||||
const (
|
||||
TransportDNS = "dns"
|
||||
TransportTLS = "tls"
|
||||
TransportGRPC = "grpc"
|
||||
TransportHTTPS = "https"
|
||||
)
|
||||
|
|
|
@ -5,15 +5,21 @@ import (
|
|||
"net"
|
||||
"os"
|
||||
|
||||
"github.com/coredns/coredns/plugin/pkg/transport"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// ParseHostPortOrFile parses the strings in s, each string can either be a address,
|
||||
// address:port or a filename. The address part is checked and the filename case a
|
||||
// resolv.conf like file is parsed and the nameserver found are returned.
|
||||
// ParseHostPortOrFile parses the strings in s, each string can either be a
|
||||
// address, [scheme://]address:port or a filename. The address part is checked
|
||||
// and in case of filename a resolv.conf like file is (assumed) and parsed and
|
||||
// the nameservers found are returned.
|
||||
func ParseHostPortOrFile(s ...string) ([]string, error) {
|
||||
var servers []string
|
||||
for _, host := range s {
|
||||
for _, h := range s {
|
||||
|
||||
trans, host := transport.Parse(h)
|
||||
|
||||
addr, _, err := net.SplitHostPort(host)
|
||||
if err != nil {
|
||||
// Parse didn't work, it is not a addr:port combo
|
||||
|
@ -26,13 +32,23 @@ func ParseHostPortOrFile(s ...string) ([]string, error) {
|
|||
}
|
||||
return servers, fmt.Errorf("not an IP address or file: %q", host)
|
||||
}
|
||||
ss := net.JoinHostPort(host, "53")
|
||||
var ss string
|
||||
switch trans {
|
||||
case transport.DNS:
|
||||
ss = net.JoinHostPort(host, "53")
|
||||
case transport.TLS:
|
||||
ss = transport.TLS + "://" + net.JoinHostPort(host, transport.TLSPort)
|
||||
case transport.GRPC:
|
||||
ss = transport.GRPC + "://" + net.JoinHostPort(host, transport.GRPCPort)
|
||||
case transport.HTTPS:
|
||||
ss = transport.HTTPS + "://" + net.JoinHostPort(host, transport.HTTPSPort)
|
||||
}
|
||||
servers = append(servers, ss)
|
||||
continue
|
||||
}
|
||||
|
||||
if net.ParseIP(addr) == nil {
|
||||
// No an IP address.
|
||||
// Not an IP address.
|
||||
ss, err := tryFile(host)
|
||||
if err == nil {
|
||||
servers = append(servers, ss...)
|
||||
|
@ -40,7 +56,7 @@ func ParseHostPortOrFile(s ...string) ([]string, error) {
|
|||
}
|
||||
return servers, fmt.Errorf("not an IP address or file: %q", host)
|
||||
}
|
||||
servers = append(servers, host)
|
||||
servers = append(servers, h)
|
||||
}
|
||||
return servers, nil
|
||||
}
|
||||
|
|
49
plugin/pkg/transport/transport.go
Normal file
49
plugin/pkg/transport/transport.go
Normal file
|
@ -0,0 +1,49 @@
|
|||
package transport
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Parse returns the transport defined in s and a string where the
|
||||
// transport prefix is removed (if there was any). If no transport is defined
|
||||
// we default to TransportDNS
|
||||
func Parse(s string) (transport string, addr string) {
|
||||
switch {
|
||||
case strings.HasPrefix(s, TLS+"://"):
|
||||
s = s[len(TLS+"://"):]
|
||||
return TLS, s
|
||||
|
||||
case strings.HasPrefix(s, DNS+"://"):
|
||||
s = s[len(DNS+"://"):]
|
||||
return DNS, s
|
||||
|
||||
case strings.HasPrefix(s, GRPC+"://"):
|
||||
s = s[len(GRPC+"://"):]
|
||||
return GRPC, s
|
||||
|
||||
case strings.HasPrefix(s, HTTPS+"://"):
|
||||
s = s[len(HTTPS+"://"):]
|
||||
|
||||
return HTTPS, s
|
||||
}
|
||||
|
||||
return DNS, s
|
||||
}
|
||||
|
||||
// Supported transports.
|
||||
const (
|
||||
DNS = "dns"
|
||||
TLS = "tls"
|
||||
GRPC = "grpc"
|
||||
HTTPS = "https"
|
||||
)
|
||||
|
||||
// Port numbers for the various protocols
|
||||
const (
|
||||
// TLSPort is the default port for DNS-over-TLS.
|
||||
TLSPort = "853"
|
||||
// GRPCPort is the default port for DNS-over-gRPC.
|
||||
GRPCPort = "443"
|
||||
// HTTPSPort is the default port for DNS-over-HTTPS.
|
||||
HTTPSPort = "443"
|
||||
)
|
21
plugin/pkg/transport/transport_test.go
Normal file
21
plugin/pkg/transport/transport_test.go
Normal file
|
@ -0,0 +1,21 @@
|
|||
package transport
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestParse(t *testing.T) {
|
||||
for i, test := range []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"dns://.:53", DNS},
|
||||
{"2003::1/64.:53", DNS},
|
||||
{"grpc://example.org:1443 ", GRPC},
|
||||
{"tls://example.org ", TLS},
|
||||
{"https://example.org ", HTTPS},
|
||||
} {
|
||||
actual, _ := Parse(test.input)
|
||||
if actual != test.expected {
|
||||
t.Errorf("Test %d: Expected %s but got %s", i, test.expected, actual)
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Add table
Reference in a new issue