[#4] dialer: Support hostnames
Signed-off-by: Dmitrii Stepanov <d.stepanov@yadro.com>
This commit is contained in:
parent
b92dae991b
commit
e0145b3a5f
1 changed files with 195 additions and 0 deletions
195
dialer.go
195
dialer.go
|
@ -8,6 +8,11 @@ import (
|
|||
"net/netip"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultFallbackDelay = 300 * time.Millisecond
|
||||
)
|
||||
|
||||
// Dialer contains the single most important method from the net.Dialer.
|
||||
|
@ -38,6 +43,10 @@ type dialer struct {
|
|||
balancer balancer
|
||||
// Hook used in tests, overrides `net.Dialer.DialContext()`
|
||||
testHookDialContext func(d *net.Dialer, ctx context.Context, network, address string) (net.Conn, error)
|
||||
// Hostname resolver
|
||||
resolver net.Resolver
|
||||
// See Config.FallbackDelay description.
|
||||
fallbackDelay time.Duration
|
||||
}
|
||||
|
||||
// Subnet represents a single subnet, possibly routable from multiple interfaces.
|
||||
|
@ -65,6 +74,15 @@ type Config struct {
|
|||
Dialer net.Dialer
|
||||
// Balancer specifies algorithm used to pick source address.
|
||||
Balancer BalancerType
|
||||
// FallbackDelay specifies the length of time to wait before
|
||||
// spawning a RFC 6555 Fast Fallback connection. That is, this
|
||||
// is the amount of time to wait for IPv6 to succeed before
|
||||
// assuming that IPv6 is misconfigured and falling back to
|
||||
// IPv4.
|
||||
//
|
||||
// If zero, a default delay of 300ms is used.
|
||||
// A negative value disables Fast Fallback support.
|
||||
FallbackDelay time.Duration
|
||||
}
|
||||
|
||||
// NewDialer ...
|
||||
|
@ -105,6 +123,11 @@ func NewDialer(c Config) (Multidialer, error) {
|
|||
}
|
||||
|
||||
d.restrict = c.Restrict
|
||||
d.resolver.Dial = d.dialContextIP
|
||||
d.fallbackDelay = c.FallbackDelay
|
||||
if d.fallbackDelay == 0 {
|
||||
d.fallbackDelay = defaultFallbackDelay
|
||||
}
|
||||
|
||||
return &d, nil
|
||||
}
|
||||
|
@ -168,10 +191,164 @@ func processSubnet(subnet string, sources []iface) (Subnet, error) {
|
|||
// Hostnames for address are currently not supported.
|
||||
func (d *dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
addr, err := netip.ParseAddrPort(address)
|
||||
if err != nil { //try resolve as hostname
|
||||
return d.dialContextHostname(ctx, network, address)
|
||||
}
|
||||
return d.dialAddr(ctx, network, address, addr)
|
||||
}
|
||||
|
||||
func (d *dialer) dialContextIP(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
addr, err := netip.ParseAddrPort(address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return d.dialAddr(ctx, network, address, addr)
|
||||
}
|
||||
|
||||
func (d *dialer) dialContextHostname(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
// https://github.com/golang/go/blob/release-branch.go1.21/src/net/dial.go#L488
|
||||
addrPorts, err := d.resolveHostIPs(ctx, address, network)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var primaries, fallbacks []netip.AddrPort
|
||||
if d.fallbackDelay >= 0 && network == "tcp" {
|
||||
primaries, fallbacks = splitByType(addrPorts)
|
||||
} else {
|
||||
primaries = addrPorts
|
||||
}
|
||||
|
||||
return d.dialParallel(ctx, network, primaries, fallbacks)
|
||||
}
|
||||
|
||||
func (d *dialer) dialSerial(ctx context.Context, network string, addrs []netip.AddrPort) (net.Conn, error) {
|
||||
var firstErr error
|
||||
for _, addr := range addrs {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
conn, err := d.dialAddr(ctx, network, addr.String(), addr)
|
||||
if err == nil {
|
||||
return conn, nil
|
||||
}
|
||||
if firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("failed to connect any resolved address: %w", firstErr)
|
||||
}
|
||||
|
||||
func (d *dialer) dialParallel(ctx context.Context, network string, primaries, fallbacks []netip.AddrPort) (net.Conn, error) {
|
||||
if len(fallbacks) == 0 {
|
||||
return d.dialSerial(ctx, network, primaries)
|
||||
}
|
||||
|
||||
returned := make(chan struct{})
|
||||
defer close(returned)
|
||||
|
||||
type dialResult struct {
|
||||
net.Conn
|
||||
error
|
||||
primary bool
|
||||
done bool
|
||||
}
|
||||
results := make(chan dialResult)
|
||||
|
||||
dialerFunc := func(ctx context.Context, primary bool) {
|
||||
addrs := primaries
|
||||
if !primary {
|
||||
addrs = fallbacks
|
||||
}
|
||||
c, err := d.dialSerial(ctx, network, addrs)
|
||||
select {
|
||||
case results <- dialResult{Conn: c, error: err, primary: primary, done: true}:
|
||||
case <-returned:
|
||||
if c != nil {
|
||||
c.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var primary, fallback dialResult
|
||||
|
||||
primaryCtx, primaryCancel := context.WithCancel(ctx)
|
||||
defer primaryCancel()
|
||||
go dialerFunc(primaryCtx, true)
|
||||
|
||||
fallbackTimer := time.NewTimer(d.fallbackDelay)
|
||||
defer fallbackTimer.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-fallbackTimer.C:
|
||||
fallbackCtx, fallbackCancel := context.WithCancel(ctx)
|
||||
defer fallbackCancel()
|
||||
go dialerFunc(fallbackCtx, false)
|
||||
|
||||
case res := <-results:
|
||||
if res.error == nil {
|
||||
return res.Conn, nil
|
||||
}
|
||||
if res.primary {
|
||||
primary = res
|
||||
} else {
|
||||
fallback = res
|
||||
}
|
||||
if primary.done && fallback.done {
|
||||
return nil, primary.error
|
||||
}
|
||||
if res.primary && fallbackTimer.Stop() {
|
||||
// If we were able to stop the timer, that means it
|
||||
// was running (hadn't yet started the fallback), but
|
||||
// we just got an error on the primary path, so start
|
||||
// the fallback immediately (in 0 nanoseconds).
|
||||
fallbackTimer.Reset(0)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (d *dialer) resolveHostIPs(ctx context.Context, address string, network string) ([]netip.AddrPort, error) {
|
||||
host, port, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid address format: %w", err)
|
||||
}
|
||||
|
||||
portnum, err := d.resolver.LookupPort(ctx, network, port)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid port format: %w", err)
|
||||
}
|
||||
|
||||
ips, err := d.resolver.LookupHost(ctx, host)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to resolve host address: %w", err)
|
||||
}
|
||||
|
||||
if len(ips) == 0 {
|
||||
return nil, fmt.Errorf("failed to resolve address for [%s]%s", network, address)
|
||||
}
|
||||
|
||||
var ipAddrs []netip.Addr
|
||||
for _, ip := range ips {
|
||||
ipAddr, err := netip.ParseAddr(ip)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse ip address '%s': %w", ip, err)
|
||||
}
|
||||
ipAddrs = append(ipAddrs, ipAddr)
|
||||
}
|
||||
|
||||
var addrPorts []netip.AddrPort
|
||||
for _, ipAddr := range ipAddrs {
|
||||
addrPorts = append(addrPorts, netip.AddrPortFrom(ipAddr, uint16(portnum)))
|
||||
}
|
||||
return addrPorts, nil
|
||||
}
|
||||
|
||||
func (d *dialer) dialAddr(ctx context.Context, network, address string, addr netip.AddrPort) (net.Conn, error) {
|
||||
d.mtx.RLock()
|
||||
defer d.mtx.RUnlock()
|
||||
|
||||
|
@ -217,3 +394,21 @@ func (d *dialer) UpdateInterface(iface string, addr netip.Addr, up bool) error {
|
|||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// splitByType divides an address list into two categories:
|
||||
// the first address, and any with same type, are returned as
|
||||
// primaries, while addresses with the opposite type are returned
|
||||
// as fallbacks.
|
||||
func splitByType(addrs []netip.AddrPort) (primaries []netip.AddrPort, fallbacks []netip.AddrPort) {
|
||||
var primaryLabel bool
|
||||
for i, addr := range addrs {
|
||||
label := addr.Addr().Is4()
|
||||
if i == 0 || label == primaryLabel {
|
||||
primaryLabel = label
|
||||
primaries = append(primaries, addr)
|
||||
} else {
|
||||
fallbacks = append(fallbacks, addr)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue