Cleanup ParseHostOrFile (#2100)

Create plugin/pkg/transport that holds the transport related functions.
This needed to be a new pkg to prevent cyclic import errors.

This cleans up a bunch of duplicated code in core/dnsserver that also
tried to parse a transport (now all done in transport.Parse).

Signed-off-by: Miek Gieben <miek@miek.nl>
This commit is contained in:
Miek Gieben 2018-09-19 07:29:37 +01:00 committed by GitHub
parent 2f1223c36a
commit c349446a23
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
24 changed files with 182 additions and 221 deletions

View file

@ -6,6 +6,7 @@ import (
"strings"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/miekg/dns"
)
@ -27,43 +28,13 @@ func (z zoneAddr) String() string {
return s
}
// Transport returns the protocol of the string s
func Transport(s string) string {
switch {
case strings.HasPrefix(s, TransportTLS+"://"):
return TransportTLS
case strings.HasPrefix(s, TransportDNS+"://"):
return TransportDNS
case strings.HasPrefix(s, TransportGRPC+"://"):
return TransportGRPC
case strings.HasPrefix(s, TransportHTTPS+"://"):
return TransportHTTPS
}
return TransportDNS
}
// normalizeZone parses an zone string into a structured format with separate
// host, and port portions, as well as the original input string.
func normalizeZone(str string) (zoneAddr, error) {
var err error
// Default to DNS if there isn't a transport protocol prefix.
trans := TransportDNS
switch {
case strings.HasPrefix(str, TransportTLS+"://"):
trans = TransportTLS
str = str[len(TransportTLS+"://"):]
case strings.HasPrefix(str, TransportDNS+"://"):
trans = TransportDNS
str = str[len(TransportDNS+"://"):]
case strings.HasPrefix(str, TransportGRPC+"://"):
trans = TransportGRPC
str = str[len(TransportGRPC+"://"):]
case strings.HasPrefix(str, TransportHTTPS+"://"):
trans = TransportHTTPS
str = str[len(TransportHTTPS+"://"):]
}
var trans string
trans, str = transport.Parse(str)
host, port, ipnet, err := plugin.SplitHostPort(str)
if err != nil {
@ -71,17 +42,15 @@ func normalizeZone(str string) (zoneAddr, error) {
}
if port == "" {
if trans == TransportDNS {
switch trans {
case transport.DNS:
port = Port
}
if trans == TransportTLS {
port = TLSPort
}
if trans == TransportGRPC {
port = GRPCPort
}
if trans == TransportHTTPS {
port = HTTPSPort
case transport.TLS:
port = transport.TLSPort
case transport.GRPC:
port = transport.GRPCPort
case transport.HTTPS:
port = transport.HTTPSPort
}
}
@ -103,14 +72,6 @@ func SplitProtocolHostPort(address string) (protocol string, ip string, port str
}
}
// Supported transports.
const (
TransportDNS = "dns"
TransportTLS = "tls"
TransportGRPC = "grpc"
TransportHTTPS = "https"
)
type zoneOverlap struct {
registeredAddr map[zoneAddr]zoneAddr // each zoneAddr is registered once by its key
unboundOverlap map[zoneAddr]zoneAddr // the "no bind" equiv ZoneAdddr is registered by its original key

View file

@ -192,21 +192,3 @@ func TestOverlapAddressChecker(t *testing.T) {
}
}
}
func TestTransport(t *testing.T) {
for i, test := range []struct {
input string
expected string
}{
{"dns://.:53", TransportDNS},
{"2003::1/64.:53", TransportDNS},
{"grpc://example.org:1443 ", TransportGRPC},
{"tls://example.org ", TransportTLS},
{"https://example.org ", TransportHTTPS},
} {
actual := Transport(test.input)
if actual != test.expected {
t.Errorf("Test %d: Expected %s but got %s", i, test.expected, actual)
}
}
}

View file

@ -9,6 +9,7 @@ import (
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/pkg/dnsutil"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/mholt/caddy"
"github.com/mholt/caddy/caddyfile"
@ -111,29 +112,29 @@ func (h *dnsContext) MakeServers() ([]caddy.Server, error) {
var servers []caddy.Server
for addr, group := range groups {
// switch on addr
switch Transport(addr) {
case TransportDNS:
switch tr, _ := transport.Parse(addr); tr {
case transport.DNS:
s, err := NewServer(addr, group)
if err != nil {
return nil, err
}
servers = append(servers, s)
case TransportTLS:
case transport.TLS:
s, err := NewServerTLS(addr, group)
if err != nil {
return nil, err
}
servers = append(servers, s)
case TransportGRPC:
case transport.GRPC:
s, err := NewServergRPC(addr, group)
if err != nil {
return nil, err
}
servers = append(servers, s)
case TransportHTTPS:
case transport.HTTPS:
s, err := NewServerHTTPS(addr, group)
if err != nil {
return nil, err
@ -234,16 +235,8 @@ func groupConfigsByListenAddr(configs []*Config) (map[string][]*Config, error) {
return groups, nil
}
const (
// DefaultPort is the default port.
DefaultPort = "53"
// TLSPort is the default port for DNS-over-TLS.
TLSPort = "853"
// GRPCPort is the default port for DNS-over-gRPC.
GRPCPort = "443"
// HTTPSPort is the default port for DNS-over-HTTPS.
HTTPSPort = "443"
)
const DefaultPort = "53"
// These "soft defaults" are configurable by
// command line flags, etc.

View file

@ -15,6 +15,7 @@ import (
"github.com/coredns/coredns/plugin/pkg/log"
"github.com/coredns/coredns/plugin/pkg/rcode"
"github.com/coredns/coredns/plugin/pkg/trace"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
@ -134,7 +135,7 @@ func (s *Server) ServePacket(p net.PacketConn) error {
// Listen implements caddy.TCPServer interface.
func (s *Server) Listen() (net.Listener, error) {
l, err := net.Listen("tcp", s.Addr[len(TransportDNS+"://"):])
l, err := net.Listen("tcp", s.Addr[len(transport.DNS+"://"):])
if err != nil {
return nil, err
}
@ -143,7 +144,7 @@ func (s *Server) Listen() (net.Listener, error) {
// ListenPacket implements caddy.UDPServer interface.
func (s *Server) ListenPacket() (net.PacketConn, error) {
p, err := net.ListenPacket("udp", s.Addr[len(TransportDNS+"://"):])
p, err := net.ListenPacket("udp", s.Addr[len(transport.DNS+"://"):])
if err != nil {
return nil, err
}

View file

@ -8,6 +8,7 @@ import (
"net"
"github.com/coredns/coredns/pb"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/coredns/coredns/plugin/pkg/watch"
"github.com/grpc-ecosystem/grpc-opentracing/go/otgrpc"
@ -73,7 +74,7 @@ func (s *ServergRPC) ServePacket(p net.PacketConn) error { return nil }
// Listen implements caddy.TCPServer interface.
func (s *ServergRPC) Listen() (net.Listener, error) {
l, err := net.Listen("tcp", s.Addr[len(TransportGRPC+"://"):])
l, err := net.Listen("tcp", s.Addr[len(transport.GRPC+"://"):])
if err != nil {
return nil, err
}
@ -90,7 +91,7 @@ func (s *ServergRPC) OnStartupComplete() {
return
}
out := startUpZones(TransportGRPC+"://", s.Addr, s.zones)
out := startUpZones(transport.GRPC+"://", s.Addr, s.zones)
if out != "" {
fmt.Print(out)
}

View file

@ -12,6 +12,7 @@ import (
"github.com/coredns/coredns/plugin/pkg/dnsutil"
"github.com/coredns/coredns/plugin/pkg/doh"
"github.com/coredns/coredns/plugin/pkg/response"
"github.com/coredns/coredns/plugin/pkg/transport"
)
// ServerHTTPS represents an instance of a DNS-over-HTTPS server.
@ -60,7 +61,7 @@ func (s *ServerHTTPS) ServePacket(p net.PacketConn) error { return nil }
// Listen implements caddy.TCPServer interface.
func (s *ServerHTTPS) Listen() (net.Listener, error) {
l, err := net.Listen("tcp", s.Addr[len(TransportHTTPS+"://"):])
l, err := net.Listen("tcp", s.Addr[len(transport.HTTPS+"://"):])
if err != nil {
return nil, err
}
@ -77,7 +78,7 @@ func (s *ServerHTTPS) OnStartupComplete() {
return
}
out := startUpZones(TransportHTTPS+"://", s.Addr, s.zones)
out := startUpZones(transport.HTTPS+"://", s.Addr, s.zones)
if out != "" {
fmt.Print(out)
}

View file

@ -6,6 +6,8 @@ import (
"fmt"
"net"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/miekg/dns"
)
@ -55,7 +57,7 @@ func (s *ServerTLS) ServePacket(p net.PacketConn) error { return nil }
// Listen implements caddy.TCPServer interface.
func (s *ServerTLS) Listen() (net.Listener, error) {
l, err := net.Listen("tcp", s.Addr[len(TransportTLS+"://"):])
l, err := net.Listen("tcp", s.Addr[len(transport.TLS+"://"):])
if err != nil {
return nil, err
}
@ -72,7 +74,7 @@ func (s *ServerTLS) OnStartupComplete() {
return
}
out := startUpZones(TransportTLS+"://", s.Addr, s.zones)
out := startUpZones(transport.TLS+"://", s.Addr, s.zones)
if out != "" {
fmt.Print(out)
}

View file

@ -35,16 +35,16 @@ func averageTimeout(currentAvg *int64, observedDuration time.Duration, weight in
atomic.AddInt64(currentAvg, int64(observedDuration-dt)/weight)
}
func (t *transport) dialTimeout() time.Duration {
func (t *Transport) dialTimeout() time.Duration {
return limitTimeout(&t.avgDialTime, minDialTimeout, maxDialTimeout)
}
func (t *transport) updateDialTimeout(newDialTime time.Duration) {
func (t *Transport) updateDialTimeout(newDialTime time.Duration) {
averageTimeout(&t.avgDialTime, newDialTime, cumulativeAvgWeight)
}
// Dial dials the address configured in transport, potentially reusing a connection or creating a new one.
func (t *transport) Dial(proto string) (*dns.Conn, bool, error) {
func (t *Transport) Dial(proto string) (*dns.Conn, bool, error) {
// If tls has been configured; use it.
if t.tlsConfig != nil {
proto = "tcp-tls"

View file

@ -4,6 +4,7 @@ import (
"testing"
"github.com/coredns/coredns/plugin/pkg/dnstest"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/coredns/coredns/plugin/test"
"github.com/coredns/coredns/request"
@ -19,7 +20,7 @@ func TestForward(t *testing.T) {
})
defer s.Close()
p := NewProxy(s.Addr, DNS)
p := NewProxy(s.Addr, transport.DNS)
f := New()
f.SetProxy(p)
defer f.Close()
@ -51,7 +52,7 @@ func TestForwardRefused(t *testing.T) {
})
defer s.Close()
p := NewProxy(s.Addr, DNS)
p := NewProxy(s.Addr, transport.DNS)
f := New()
f.SetProxy(p)
defer f.Close()

View file

@ -5,6 +5,8 @@ import (
"sync/atomic"
"time"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/miekg/dns"
)
@ -17,10 +19,10 @@ type HealthChecker interface {
// dnsHc is a health checker for a DNS endpoint (DNS, and DoT).
type dnsHc struct{ c *dns.Client }
// NewHealthChecker returns a new HealthChecker based on protocol.
func NewHealthChecker(protocol int) HealthChecker {
switch protocol {
case DNS, TLS:
// NewHealthChecker returns a new HealthChecker based on transport.
func NewHealthChecker(trans string) HealthChecker {
switch trans {
case transport.DNS, transport.TLS:
c := new(dns.Client)
c.Net = "udp"
c.ReadTimeout = 1 * time.Second

View file

@ -7,6 +7,7 @@ import (
"time"
"github.com/coredns/coredns/plugin/pkg/dnstest"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/coredns/coredns/plugin/test"
"github.com/miekg/dns"
@ -25,7 +26,7 @@ func TestHealth(t *testing.T) {
})
defer s.Close()
p := NewProxy(s.Addr, DNS)
p := NewProxy(s.Addr, transport.DNS)
f := New()
f.SetProxy(p)
defer f.Close()
@ -65,7 +66,7 @@ func TestHealthTimeout(t *testing.T) {
})
defer s.Close()
p := NewProxy(s.Addr, DNS)
p := NewProxy(s.Addr, transport.DNS)
f := New()
f.SetProxy(p)
defer f.Close()
@ -109,7 +110,7 @@ func TestHealthFailTwice(t *testing.T) {
})
defer s.Close()
p := NewProxy(s.Addr, DNS)
p := NewProxy(s.Addr, transport.DNS)
f := New()
f.SetProxy(p)
defer f.Close()
@ -132,7 +133,7 @@ func TestHealthMaxFails(t *testing.T) {
})
defer s.Close()
p := NewProxy(s.Addr, DNS)
p := NewProxy(s.Addr, transport.DNS)
f := New()
f.maxfails = 2
f.SetProxy(p)
@ -163,7 +164,7 @@ func TestHealthNoMaxFails(t *testing.T) {
})
defer s.Close()
p := NewProxy(s.Addr, DNS)
p := NewProxy(s.Addr, transport.DNS)
f := New()
f.maxfails = 0
f.SetProxy(p)

View file

@ -7,6 +7,7 @@ package forward
import (
"context"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
@ -81,7 +82,7 @@ func (f *Forward) Lookup(state request.Request, name string, typ uint16) (*dns.M
func NewLookup(addr []string) *Forward {
f := New()
for i := range addr {
p := NewProxy(addr[i], DNS)
p := NewProxy(addr[i], transport.DNS)
f.SetProxy(p)
}
return f

View file

@ -4,6 +4,7 @@ import (
"testing"
"github.com/coredns/coredns/plugin/pkg/dnstest"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/coredns/coredns/plugin/test"
"github.com/coredns/coredns/request"
@ -19,7 +20,7 @@ func TestLookup(t *testing.T) {
})
defer s.Close()
p := NewProxy(s.Addr, DNS)
p := NewProxy(s.Addr, transport.DNS)
f := New()
f.SetProxy(p)
defer f.Close()

View file

@ -15,8 +15,8 @@ type persistConn struct {
used time.Time
}
// transport hold the persistent cache.
type transport struct {
// Transport hold the persistent cache.
type Transport struct {
avgDialTime int64 // kind of average time of dial time
conns map[string][]*persistConn // Buckets for udp, tcp and tcp-tls.
expire time.Duration // After this duration a connection is expired.
@ -29,8 +29,8 @@ type transport struct {
stop chan bool
}
func newTransport(addr string) *transport {
t := &transport{
func newTransport(addr string) *Transport {
t := &Transport{
avgDialTime: int64(defaultDialTimeout / 2),
conns: make(map[string][]*persistConn),
expire: defaultExpire,
@ -45,7 +45,7 @@ func newTransport(addr string) *transport {
// len returns the number of connection, used for metrics. Can only be safely
// used inside connManager() because of data races.
func (t *transport) len() int {
func (t *Transport) len() int {
l := 0
for _, conns := range t.conns {
l += len(conns)
@ -54,7 +54,7 @@ func (t *transport) len() int {
}
// connManagers manages the persistent connection cache for UDP and TCP.
func (t *transport) connManager() {
func (t *Transport) connManager() {
ticker := time.NewTicker(t.expire)
Wait:
for {
@ -115,7 +115,7 @@ func closeConns(conns []*persistConn) {
}
// cleanup removes connections from cache.
func (t *transport) cleanup(all bool) {
func (t *Transport) cleanup(all bool) {
staleTime := time.Now().Add(-t.expire)
for proto, stack := range t.conns {
if len(stack) == 0 {
@ -144,19 +144,19 @@ func (t *transport) cleanup(all bool) {
}
// Yield return the connection to transport for reuse.
func (t *transport) Yield(c *dns.Conn) { t.yield <- c }
func (t *Transport) Yield(c *dns.Conn) { t.yield <- c }
// Start starts the transport's connection manager.
func (t *transport) Start() { go t.connManager() }
func (t *Transport) Start() { go t.connManager() }
// Stop stops the transport's connection manager.
func (t *transport) Stop() { close(t.stop) }
func (t *Transport) Stop() { close(t.stop) }
// SetExpire sets the connection expire time in transport.
func (t *transport) SetExpire(expire time.Duration) { t.expire = expire }
func (t *Transport) SetExpire(expire time.Duration) { t.expire = expire }
// SetTLSConfig sets the TLS config in transport.
func (t *transport) SetTLSConfig(cfg *tls.Config) { t.tlsConfig = cfg }
func (t *Transport) SetTLSConfig(cfg *tls.Config) { t.tlsConfig = cfg }
const (
defaultExpire = 10 * time.Second

View file

@ -1,30 +0,0 @@
package forward
// Copied from coredns/core/dnsserver/address.go
import (
"strings"
)
// protocol returns the protocol of the string s. The second string returns s
// with the prefix chopped off.
func protocol(s string) (int, string) {
switch {
case strings.HasPrefix(s, _tls+"://"):
return TLS, s[len(_tls)+3:]
case strings.HasPrefix(s, _dns+"://"):
return DNS, s[len(_dns)+3:]
}
return DNS, s
}
// Supported protocols.
const (
DNS = iota + 1
TLS
)
const (
_dns = "dns"
_tls = "tls"
)

View file

@ -18,7 +18,7 @@ type Proxy struct {
// Connection caching
expire time.Duration
transport *transport
transport *Transport
// health checking
probe *up.Probe
@ -26,7 +26,7 @@ type Proxy struct {
}
// NewProxy returns a new proxy.
func NewProxy(addr string, protocol int) *Proxy {
func NewProxy(addr, trans string) *Proxy {
p := &Proxy{
addr: addr,
fails: 0,
@ -34,7 +34,7 @@ func NewProxy(addr string, protocol int) *Proxy {
transport: newTransport(addr),
avgRtt: int64(maxTimeout / 2),
}
p.health = NewHealthChecker(protocol)
p.health = NewHealthChecker(trans)
runtime.SetFinalizer(p, (*Proxy).finalizer)
return p
}

View file

@ -5,6 +5,7 @@ import (
"testing"
"github.com/coredns/coredns/plugin/pkg/dnstest"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/coredns/coredns/plugin/test"
"github.com/coredns/coredns/request"
@ -26,7 +27,7 @@ func TestProxyClose(t *testing.T) {
ctx := context.TODO()
for i := 0; i < 100; i++ {
p := NewProxy(s.Addr, DNS)
p := NewProxy(s.Addr, transport.DNS)
p.start(hcInterval)
go func() { p.Connect(ctx, state, options{}) }()
@ -95,7 +96,7 @@ func TestProxyTLSFail(t *testing.T) {
}
func TestProtocolSelection(t *testing.T) {
p := NewProxy("bad_address", DNS)
p := NewProxy("bad_address", transport.DNS)
stateUDP := request.Request{W: &test.ResponseWriter{}, Req: new(dns.Msg)}
stateTCP := request.Request{W: &test.ResponseWriter{TCP: true}, Req: new(dns.Msg)}

View file

@ -2,7 +2,6 @@ package forward
import (
"fmt"
"net"
"strconv"
"time"
@ -11,6 +10,7 @@ import (
"github.com/coredns/coredns/plugin/metrics"
"github.com/coredns/coredns/plugin/pkg/dnsutil"
pkgtls "github.com/coredns/coredns/plugin/pkg/tls"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/mholt/caddy"
"github.com/mholt/caddy/caddyfile"
@ -93,8 +93,6 @@ func parseForward(c *caddy.Controller) (*Forward, error) {
func ParseForwardStanza(c *caddyfile.Dispenser) (*Forward, error) {
f := New()
protocols := map[int]int{}
if !c.Args(&f.from) {
return f, c.ArgErr()
}
@ -105,41 +103,17 @@ func ParseForwardStanza(c *caddyfile.Dispenser) (*Forward, error) {
return f, c.ArgErr()
}
// A bit fiddly, but first check if we've got protocols and if so add them back in when we create the proxies.
protocols = make(map[int]int)
for i := range to {
protocols[i], to[i] = protocol(to[i])
}
// If parseHostPortOrFile expands a file with a lot of nameserver our accounting in protocols doesn't make
// any sense anymore... For now: lets don't care.
toHosts, err := dnsutil.ParseHostPortOrFile(to...)
if err != nil {
return f, err
}
for i, h := range toHosts {
// Double check the port, if e.g. is 53 and the transport is TLS make it 853.
// This can be somewhat annoying because you *can't* have TLS on port 53 then.
switch protocols[i] {
case TLS:
h1, p, err := net.SplitHostPort(h)
if err != nil {
break
}
// This is more of a bug in dnsutil.ParseHostPortOrFile that defaults to
// 53 because it doesn't know about the tls:// // and friends (that should be fixed). Hence
// Fix the port number here, back to what the user intended.
if p == "53" {
h = net.JoinHostPort(h1, "853")
}
}
// We can't set tlsConfig here, because we haven't parsed it yet.
// We set it below at the end of parseBlock, use nil now.
p := NewProxy(h, protocols[i])
transports := make([]string, len(toHosts))
for i, host := range toHosts {
trans, h := transport.Parse(host)
p := NewProxy(h, trans)
f.proxies = append(f.proxies, p)
transports[i] = trans
}
for c.NextBlock() {
@ -153,7 +127,7 @@ func ParseForwardStanza(c *caddyfile.Dispenser) (*Forward, error) {
}
for i := range f.proxies {
// Only set this for proxies that need it.
if protocols[i] == TLS {
if transports[i] == transport.TLS {
f.proxies[i].SetTLSConfig(f.tlsConfig)
}
f.proxies[i].SetExpire(f.expire)

View file

@ -5,6 +5,7 @@ import (
"testing"
"github.com/coredns/coredns/plugin/pkg/dnstest"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/coredns/coredns/plugin/test"
"github.com/coredns/coredns/request"
@ -34,7 +35,7 @@ func TestLookupTruncated(t *testing.T) {
})
defer s.Close()
p := NewProxy(s.Addr, DNS)
p := NewProxy(s.Addr, transport.DNS)
f := New()
f.SetProxy(p)
defer f.Close()
@ -88,9 +89,9 @@ func TestForwardTruncated(t *testing.T) {
f := New()
p1 := NewProxy(s.Addr, DNS)
p1 := NewProxy(s.Addr, transport.DNS)
f.SetProxy(p1)
p2 := NewProxy(s.Addr, DNS)
p2 := NewProxy(s.Addr, transport.DNS)
f.SetProxy(p2)
defer f.Close()

View file

@ -6,6 +6,8 @@ import (
"strconv"
"strings"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/miekg/dns"
)
@ -61,22 +63,10 @@ type (
// Normalize will return the host portion of host, stripping
// of any port or transport. The host will also be fully qualified and lowercased.
func (h Host) Normalize() string {
s := string(h)
_, s = transport.Parse(s)
switch {
case strings.HasPrefix(s, TransportTLS+"://"):
s = s[len(TransportTLS+"://"):]
case strings.HasPrefix(s, TransportDNS+"://"):
s = s[len(TransportDNS+"://"):]
case strings.HasPrefix(s, TransportGRPC+"://"):
s = s[len(TransportGRPC+"://"):]
case strings.HasPrefix(s, TransportHTTPS+"://"):
s = s[len(TransportHTTPS+"://"):]
}
// The error can be ignore here, because this function is called after the corefile
// has already been vetted.
// The error can be ignore here, because this function is called after the corefile has already been vetted.
host, _, _, _ := SplitHostPort(s)
return Name(host).Normalize()
}
@ -138,11 +128,3 @@ func SplitHostPort(s string) (host, port string, ipnet *net.IPNet, err error) {
}
return host, port, n, nil
}
// Duplicated from core/dnsserver/address.go !
const (
TransportDNS = "dns"
TransportTLS = "tls"
TransportGRPC = "grpc"
TransportHTTPS = "https"
)

View file

@ -5,15 +5,21 @@ import (
"net"
"os"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/miekg/dns"
)
// ParseHostPortOrFile parses the strings in s, each string can either be a address,
// address:port or a filename. The address part is checked and the filename case a
// resolv.conf like file is parsed and the nameserver found are returned.
// ParseHostPortOrFile parses the strings in s, each string can either be a
// address, [scheme://]address:port or a filename. The address part is checked
// and in case of filename a resolv.conf like file is (assumed) and parsed and
// the nameservers found are returned.
func ParseHostPortOrFile(s ...string) ([]string, error) {
var servers []string
for _, host := range s {
for _, h := range s {
trans, host := transport.Parse(h)
addr, _, err := net.SplitHostPort(host)
if err != nil {
// Parse didn't work, it is not a addr:port combo
@ -26,13 +32,23 @@ func ParseHostPortOrFile(s ...string) ([]string, error) {
}
return servers, fmt.Errorf("not an IP address or file: %q", host)
}
ss := net.JoinHostPort(host, "53")
var ss string
switch trans {
case transport.DNS:
ss = net.JoinHostPort(host, "53")
case transport.TLS:
ss = transport.TLS + "://" + net.JoinHostPort(host, transport.TLSPort)
case transport.GRPC:
ss = transport.GRPC + "://" + net.JoinHostPort(host, transport.GRPCPort)
case transport.HTTPS:
ss = transport.HTTPS + "://" + net.JoinHostPort(host, transport.HTTPSPort)
}
servers = append(servers, ss)
continue
}
if net.ParseIP(addr) == nil {
// No an IP address.
// Not an IP address.
ss, err := tryFile(host)
if err == nil {
servers = append(servers, ss...)
@ -40,7 +56,7 @@ func ParseHostPortOrFile(s ...string) ([]string, error) {
}
return servers, fmt.Errorf("not an IP address or file: %q", host)
}
servers = append(servers, host)
servers = append(servers, h)
}
return servers, nil
}

View file

@ -0,0 +1,49 @@
package transport
import (
"strings"
)
// Parse returns the transport defined in s and a string where the
// transport prefix is removed (if there was any). If no transport is defined
// we default to TransportDNS
func Parse(s string) (transport string, addr string) {
switch {
case strings.HasPrefix(s, TLS+"://"):
s = s[len(TLS+"://"):]
return TLS, s
case strings.HasPrefix(s, DNS+"://"):
s = s[len(DNS+"://"):]
return DNS, s
case strings.HasPrefix(s, GRPC+"://"):
s = s[len(GRPC+"://"):]
return GRPC, s
case strings.HasPrefix(s, HTTPS+"://"):
s = s[len(HTTPS+"://"):]
return HTTPS, s
}
return DNS, s
}
// Supported transports.
const (
DNS = "dns"
TLS = "tls"
GRPC = "grpc"
HTTPS = "https"
)
// Port numbers for the various protocols
const (
// TLSPort is the default port for DNS-over-TLS.
TLSPort = "853"
// GRPCPort is the default port for DNS-over-gRPC.
GRPCPort = "443"
// HTTPSPort is the default port for DNS-over-HTTPS.
HTTPSPort = "443"
)

View file

@ -0,0 +1,21 @@
package transport
import "testing"
func TestParse(t *testing.T) {
for i, test := range []struct {
input string
expected string
}{
{"dns://.:53", DNS},
{"2003::1/64.:53", DNS},
{"grpc://example.org:1443 ", GRPC},
{"tls://example.org ", TLS},
{"https://example.org ", HTTPS},
} {
actual, _ := Parse(test.input)
if actual != test.expected {
t.Errorf("Test %d: Expected %s but got %s", i, test.expected, actual)
}
}
}