package multinet

import (
	"context"
	"fmt"
	"net"
	"net/netip"
	"sync"
	"time"
)

const (
	// defaultFallbackDelay is applied by NewDialer when Config.FallbackDelay is zero.
	defaultFallbackDelay = 300 * time.Millisecond
)

// Dialer contains the single most important method from the net.Dialer.
type Dialer interface {
	DialContext(ctx context.Context, network, address string) (net.Conn, error)
}

// Compile-time checks that both implementations satisfy Dialer.
var (
	_ Dialer = (*net.Dialer)(nil)
	_ Dialer = (*dialer)(nil)
)

// dialer is the Dialer implementation returned by NewDialer.
type dialer struct {
	// Protects subnets field (recursively).
	mtx     sync.RWMutex
	subnets []Subnet
	// Default options for the net.Dialer.
	dialer net.Dialer
	// If true, allow to dial only configured subnets.
	restrict bool
	// Algorithm which picks the source address for each ip.
	balancer balancer
	// Overrides `net.Dialer.DialContext()` if specified.
	customDialContext func(d *net.Dialer, ctx context.Context, network, address string) (net.Conn, error)
	// Hostname resolver.
	resolver net.Resolver
	// See Config.FallbackDelay description.
	fallbackDelay time.Duration
}

// Subnet represents a single subnet, possibly routable from multiple source IPs.
type Subnet struct {
	Prefix    netip.Prefix
	SourceIPs []netip.Addr
}

// Config contains Multidialer configuration.
type Config struct {
	// Routable subnets.
	Subnets []Subnet
	// If true, only the configured subnets are reachable through this dialer.
	// Otherwise, addresses outside the configured subnets are dialed with
	// Dialer (or with DialContext, if it is set).
	Restrict bool
	// Dialer contains default options for the net.Dialer to use.
	// LocalAddr is overridden.
	Dialer net.Dialer
	// Balancer specifies algorithm used to pick source address.
	Balancer BalancerType
	// FallbackDelay specifies the length of time to wait before
	// spawning a RFC 6555 Fast Fallback connection. That is, this
	// is the amount of time to wait for IPv6 to succeed before
	// assuming that IPv6 is misconfigured and falling back to
	// IPv4.
	//
	// If zero, a default delay of 300ms is used.
	// A negative value disables Fast Fallback support.
	FallbackDelay time.Duration
	// DialContext is custom DialContext function.
	// If not specified, default implementation is used (`d.DialContext(ctx, network, address)`).
	DialContext func(d *net.Dialer, ctx context.Context, network, address string) (net.Conn, error)
}

// NewDialer builds a Dialer from the given configuration.
//
// It returns an error when c.Balancer is not one of the known balancer
// types. A zero c.FallbackDelay is replaced with the 300ms default.
func NewDialer(c Config) (Dialer, error) {
	var d dialer
	d.subnets = c.Subnets
	switch c.Balancer {
	case BalancerTypeNoop:
		d.balancer = &firstEnabled{d: &d}
	case BalancerTypeRoundRobin:
		d.balancer = &roundRobin{d: &d}
	default:
		return nil, fmt.Errorf("invalid balancer type: %s", c.Balancer)
	}
	d.restrict = c.Restrict
	// Connections opened through the resolver's Dial hook go through the
	// same subnet-aware path as regular dials.
	d.resolver.Dial = d.dialContextIP
	d.fallbackDelay = c.FallbackDelay
	if d.fallbackDelay == 0 {
		d.fallbackDelay = defaultFallbackDelay
	}
	d.dialer = c.Dialer
	if c.DialContext != nil {
		d.customDialContext = c.DialContext
	}
	return &d, nil
}

// DialContext implements the Dialer interface.
//
// address may be a literal "ip:port", which is dialed directly, or any
// other "host:port" form, which is resolved with the dialer's resolver
// before dialing.
func (d *dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
	addr, err := netip.ParseAddrPort(address)
	if err != nil {
		// Not a literal ip:port; try to resolve it as a hostname.
		return d.dialContextHostname(ctx, network, address)
	}
	return d.dialAddr(ctx, network, address, addr)
}

// dialContextIP dials a literal "ip:port" address and never performs name
// resolution; a non-literal address is reported as a parse error.
// NewDialer installs it as the resolver's Dial hook.
func (d *dialer) dialContextIP(ctx context.Context, network, address string) (net.Conn, error) {
	addr, err := netip.ParseAddrPort(address)
	if err != nil {
		return nil, err
	}
	return d.dialAddr(ctx, network, address, addr)
}

// dialContextHostname resolves address ("host:port") and dials the
// resulting addresses. For network "tcp" with a non-negative fallback
// delay the addresses are split into primaries and fallbacks by address
// family; otherwise all of them are treated as primaries.
func (d *dialer) dialContextHostname(ctx context.Context, network, address string) (net.Conn, error) {
	// https://github.com/golang/go/blob/release-branch.go1.21/src/net/dial.go#L488
	addrPorts, err := d.resolveHostIPs(ctx, address, network)
	if err != nil {
		return nil, err
	}

	var primaries, fallbacks []netip.AddrPort
	if d.fallbackDelay >= 0 && network == "tcp" {
		primaries, fallbacks = splitByType(addrPorts)
	} else {
		primaries = addrPorts
	}

	return d.dialParallel(ctx, network, primaries, fallbacks)
}

// dialSerial tries addrs one after another and returns the first
// connection that succeeds. If ctx is done before an attempt starts,
// ctx.Err() is returned. If every attempt fails, the error of the first
// failed attempt is returned, wrapped.
func (d *dialer) dialSerial(ctx context.Context, network string, addrs []netip.AddrPort) (net.Conn, error) {
	var firstErr error

	for _, addr := range addrs {
		select {
		case <-ctx.Done():
			return nil, ctx.Err()
		default:
		}

		conn, err := d.dialAddr(ctx, network, addr.String(), addr)
		if err == nil {
			return conn, nil
		}
		if firstErr == nil {
			firstErr = err
		}
	}

	if firstErr == nil {
		// addrs was empty: nothing was attempted, so there is no error to
		// wrap. Wrapping nil with %w would yield a "%!w(<nil>)" message.
		return nil, fmt.Errorf("failed to connect any resolved address: no addresses to dial")
	}
	return nil, fmt.Errorf("failed to connect any resolved address: %w", firstErr)
}

// dialParallel races two dialSerial runs: primaries start immediately,
// fallbacks start once the fallback delay elapses or as soon as the
// primaries fail, whichever comes first. It returns the first established
// connection. If both runs fail, the primary error is returned. With no
// fallbacks it degenerates to a plain dialSerial over primaries.
func (d *dialer) dialParallel(ctx context.Context, network string, primaries, fallbacks []netip.AddrPort) (net.Conn, error) {
	if len(fallbacks) == 0 {
		return d.dialSerial(ctx, network, primaries)
	}

	// returned is closed when this function returns, so that a racer
	// finishing late can close its connection instead of blocking on results.
	returned := make(chan struct{})
	defer close(returned)

	type dialResult struct {
		net.Conn
		error
		primary bool
		done    bool
	}
	results := make(chan dialResult)

	dialerFunc := func(ctx context.Context, primary bool) {
		addrs := primaries
		if !primary {
			addrs = fallbacks
		}
		c, err := d.dialSerial(ctx, network, addrs)
		select {
		case results <- dialResult{Conn: c, error: err, primary: primary, done: true}:
		case <-returned:
			if c != nil {
				c.Close()
			}
		}
	}

	var primary, fallback dialResult

	// Start the primary racer.
	primaryCtx, primaryCancel := context.WithCancel(ctx)
	defer primaryCancel()
	go dialerFunc(primaryCtx, true)

	// Start the timer for the fallback racer.
	fallbackTimer := time.NewTimer(d.fallbackDelay)
	defer fallbackTimer.Stop()

	for {
		select {
		case <-fallbackTimer.C:
			fallbackCtx, fallbackCancel := context.WithCancel(ctx)
			defer fallbackCancel()
			go dialerFunc(fallbackCtx, false)

		case res := <-results:
			if res.error == nil {
				return res.Conn, nil
			}
			if res.primary {
				primary = res
			} else {
				fallback = res
			}
			if primary.done && fallback.done {
				return nil, primary.error
			}
			if res.primary && fallbackTimer.Stop() {
				// If we were able to stop the timer, that means it
				// was running (hadn't yet started the fallback), but
				// we just got an error on the primary path, so start
				// the fallback immediately (in 0 nanoseconds).
				fallbackTimer.Reset(0)
			}
		}
	}
}

// resolveHostIPs splits address into host and port, resolves both with
// the dialer's resolver and returns one AddrPort per resolved IP, in the
// order the resolver returned them.
//
// NOTE(review): the host lookup is not restricted by network, so the
// result may contain addresses of either family even for a
// family-specific network; confirm this is intended.
func (d *dialer) resolveHostIPs(ctx context.Context, address string, network string) ([]netip.AddrPort, error) {
	host, port, err := net.SplitHostPort(address)
	if err != nil {
		return nil, fmt.Errorf("invalid address format: %w", err)
	}

	portnum, err := d.resolver.LookupPort(ctx, network, port)
	if err != nil {
		return nil, fmt.Errorf("invalid port format: %w", err)
	}

	ips, err := d.resolver.LookupHost(ctx, host)
	if err != nil {
		return nil, fmt.Errorf("failed to resolve host address: %w", err)
	}
	if len(ips) == 0 {
		return nil, fmt.Errorf("failed to resolve address for [%s]%s", network, address)
	}

	addrPorts := make([]netip.AddrPort, 0, len(ips))
	for _, ip := range ips {
		ipAddr, err := netip.ParseAddr(ip)
		if err != nil {
			return nil, fmt.Errorf("failed to parse ip address '%s': %w", ip, err)
		}
		addrPorts = append(addrPorts, netip.AddrPortFrom(ipAddr, uint16(portnum)))
	}
	return addrPorts, nil
}

// dialAddr dials addr through the balancer of the first configured subnet
// whose prefix contains it. If no subnet matches, the call fails when the
// dialer is restricted and otherwise falls back to the default dialer.
// The read lock on the subnets is held until the dial returns.
func (d *dialer) dialAddr(ctx context.Context, network, address string, addr netip.AddrPort) (net.Conn, error) {
	d.mtx.RLock()
	defer d.mtx.RUnlock()

	for i := range d.subnets {
		if d.subnets[i].Prefix.Contains(addr.Addr()) {
			return d.balancer.DialContext(ctx, &d.subnets[i], network, address)
		}
	}

	if d.restrict {
		return nil, fmt.Errorf("no suitable interface for: [%s]%s", network, address)
	}
	return d.dialContext(ctx, &d.dialer, network, address)
}

// dialContext dials with the custom DialContext hook when one is
// configured, and with nd.DialContext otherwise.
func (d *dialer) dialContext(ctx context.Context, nd *net.Dialer, network, address string) (net.Conn, error) {
	if h := d.customDialContext; h != nil {
		return h(nd, ctx, network, address)
	}
	return nd.DialContext(ctx, network, address)
}

// splitByType divides an address list into two categories:
// the first address, and any with same type, are returned as
// primaries, while addresses with the opposite type are returned
// as fallbacks.
func splitByType(addrs []netip.AddrPort) (primaries []netip.AddrPort, fallbacks []netip.AddrPort) {
	if len(addrs) == 0 {
		return nil, nil
	}

	// The family of the very first address decides which family is primary.
	primaryIsV4 := addrs[0].Addr().Is4()

	for _, ap := range addrs {
		if ap.Addr().Is4() == primaryIsV4 {
			primaries = append(primaries, ap)
			continue
		}
		// Opposite family: keep it for the fallback attempt.
		fallbacks = append(fallbacks, ap)
	}
	return primaries, fallbacks
}