diff --git a/dialer.go b/dialer.go index a0400d4..66025b6 100644 --- a/dialer.go +++ b/dialer.go @@ -8,6 +8,11 @@ import ( "net/netip" "sort" "sync" + "time" +) + +const ( + defaultFallbackDelay = 300 * time.Millisecond ) // Dialer contains the single most important method from the net.Dialer. @@ -38,6 +43,10 @@ type dialer struct { balancer balancer // Hook used in tests, overrides `net.Dialer.DialContext()` testHookDialContext func(d *net.Dialer, ctx context.Context, network, address string) (net.Conn, error) + // Hostname resolver + resolver net.Resolver + // See Config.FallbackDelay description. + fallbackDelay time.Duration } // Subnet represents a single subnet, possibly routable from multiple interfaces. @@ -65,6 +74,15 @@ type Config struct { Dialer net.Dialer // Balancer specifies algorithm used to pick source address. Balancer BalancerType + // FallbackDelay specifies the length of time to wait before + // spawning a RFC 6555 Fast Fallback connection. That is, this + // is the amount of time to wait for IPv6 to succeed before + // assuming that IPv6 is misconfigured and falling back to + // IPv4. + // + // If zero, a default delay of 300ms is used. + // A negative value disables Fast Fallback support. + FallbackDelay time.Duration } // NewDialer ... @@ -105,6 +123,11 @@ func NewDialer(c Config) (Multidialer, error) { } d.restrict = c.Restrict + d.resolver.Dial = d.dialContextIP + d.fallbackDelay = c.FallbackDelay + if d.fallbackDelay == 0 { + d.fallbackDelay = defaultFallbackDelay + } return &d, nil } @@ -168,10 +191,164 @@ func processSubnet(subnet string, sources []iface) (Subnet, error) { // Hostnames for address are currently not supported. func (d *dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { addr, err := netip.ParseAddrPort(address) + if err != nil { //try resolve as hostname + return d.dialContextHostname(ctx, network, address) + } + return d.dialAddr(ctx, network, address, addr) +} + +func (d *dialer) dialContextIP(ctx context.Context, network, address string) (net.Conn, error) { + addr, err := netip.ParseAddrPort(address) + if err != nil { + return nil, err + } + return d.dialAddr(ctx, network, address, addr) +} + +func (d *dialer) dialContextHostname(ctx context.Context, network, address string) (net.Conn, error) { + // https://github.com/golang/go/blob/release-branch.go1.21/src/net/dial.go#L488 + addrPorts, err := d.resolveHostIPs(ctx, address, network) if err != nil { return nil, err } + var primaries, fallbacks []netip.AddrPort + if d.fallbackDelay >= 0 && network == "tcp" { + primaries, fallbacks = splitByType(addrPorts) + } else { + primaries = addrPorts + } + + return d.dialParallel(ctx, network, primaries, fallbacks) +} + +func (d *dialer) dialSerial(ctx context.Context, network string, addrs []netip.AddrPort) (net.Conn, error) { + var firstErr error + for _, addr := range addrs { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + + conn, err := d.dialAddr(ctx, network, addr.String(), addr) + if err == nil { + return conn, nil + } + if firstErr == nil { + firstErr = err + } + } + return nil, fmt.Errorf("failed to connect any resolved address: %w", firstErr) +} + +func (d *dialer) dialParallel(ctx context.Context, network string, primaries, fallbacks []netip.AddrPort) (net.Conn, error) { + if len(fallbacks) == 0 { + return d.dialSerial(ctx, network, primaries) + } + + returned := make(chan struct{}) + defer close(returned) + + type dialResult struct { + net.Conn + error + primary bool + done bool + } + results := make(chan dialResult) + + dialerFunc := func(ctx context.Context, primary bool) { + addrs := primaries + if !primary { + addrs = fallbacks + } + c, err := d.dialSerial(ctx, network, addrs) + select { + case results <- dialResult{Conn: c, error: err, primary: primary, done: true}: + case <-returned: + if c != nil { + c.Close() + } + } + } + + var primary, fallback dialResult + + primaryCtx, primaryCancel := context.WithCancel(ctx) + defer primaryCancel() + go dialerFunc(primaryCtx, true) + + fallbackTimer := time.NewTimer(d.fallbackDelay) + defer fallbackTimer.Stop() + + for { + select { + case <-fallbackTimer.C: + fallbackCtx, fallbackCancel := context.WithCancel(ctx) + defer fallbackCancel() + go dialerFunc(fallbackCtx, false) + + case res := <-results: + if res.error == nil { + return res.Conn, nil + } + if res.primary { + primary = res + } else { + fallback = res + } + if primary.done && fallback.done { + return nil, primary.error + } + if res.primary && fallbackTimer.Stop() { + // If we were able to stop the timer, that means it + // was running (hadn't yet started the fallback), but + // we just got an error on the primary path, so start + // the fallback immediately (in 0 nanoseconds). + fallbackTimer.Reset(0) + } + } + } +} + +func (d *dialer) resolveHostIPs(ctx context.Context, address string, network string) ([]netip.AddrPort, error) { + host, port, err := net.SplitHostPort(address) + if err != nil { + return nil, fmt.Errorf("invalid address format: %w", err) + } + + portnum, err := d.resolver.LookupPort(ctx, network, port) + if err != nil { + return nil, fmt.Errorf("invalid port format: %w", err) + } + + ips, err := d.resolver.LookupHost(ctx, host) + if err != nil { + return nil, fmt.Errorf("failed to resolve host address: %w", err) + } + + if len(ips) == 0 { + return nil, fmt.Errorf("failed to resolve address for [%s]%s", network, address) + } + + var ipAddrs []netip.Addr + for _, ip := range ips { + ipAddr, err := netip.ParseAddr(ip) + if err != nil { + return nil, fmt.Errorf("failed to parse ip address '%s': %w", ip, err) + } + ipAddrs = append(ipAddrs, ipAddr) + } + + var addrPorts []netip.AddrPort + for _, ipAddr := range ipAddrs { + addrPorts = append(addrPorts, netip.AddrPortFrom(ipAddr, uint16(portnum))) + } + return addrPorts, nil +} + +func (d *dialer) dialAddr(ctx context.Context, network, address string, addr netip.AddrPort) (net.Conn, error) { d.mtx.RLock() defer d.mtx.RUnlock() @@ -217,3 +394,21 @@ func (d *dialer) UpdateInterface(iface string, addr netip.Addr, up bool) error { } return nil } + +// splitByType divides an address list into two categories: +// the first address, and any with same type, are returned as +// primaries, while addresses with the opposite type are returned +// as fallbacks. +func splitByType(addrs []netip.AddrPort) (primaries []netip.AddrPort, fallbacks []netip.AddrPort) { + var primaryLabel bool + for i, addr := range addrs { + label := addr.Addr().Is4() + if i == 0 || label == primaryLabel { + primaryLabel = label + primaries = append(primaries, addr) + } else { + fallbacks = append(fallbacks, addr) + } + } + return +}