package multinet import ( "bytes" "context" "fmt" "net" "net/netip" "sort" "sync" "time" ) const ( defaultFallbackDelay = 300 * time.Millisecond ) // Dialer contains the single most important method from the net.Dialer. type Dialer interface { DialContext(ctx context.Context, network, address string) (net.Conn, error) } // Multidialer is like Dialer, but supports link state updates. type Multidialer interface { Dialer UpdateInterface(name string, addr netip.Addr, status bool) error } var ( _ Dialer = (*net.Dialer)(nil) _ Dialer = (*dialer)(nil) ) type dialer struct { // Protects subnets field (recursively). mtx sync.RWMutex subnets []Subnet // Default options for the net.Dialer. dialer net.Dialer // If true, allow to dial only configured subnets. restrict bool // Algorithm to which picks source address for each ip. balancer balancer // Overrides `net.Dialer.DialContext()` if specified. customDialContext func(d *net.Dialer, ctx context.Context, network, address string) (net.Conn, error) // Hostname resolver resolver net.Resolver // See Config.FallbackDelay description. fallbackDelay time.Duration } // Subnet represents a single subnet, possibly routable from multiple interfaces. type Subnet struct { Mask netip.Prefix Interfaces []Source } // Source represents a single source IP belonging to a particular subnet. type Source struct { Name string LocalAddr *net.TCPAddr Down bool } // Config contains Multidialer configuration. type Config struct { // Routable subnets to prioritize in CIDR format. Subnets []string // If true, the only configurd subnets available through this dialer. // Otherwise, a failback to the net.DefaultDialer. Restrict bool // Dialer containes default options for the net.Dialer to use. // LocalAddr is overriden. Dialer net.Dialer // Balancer specifies algorithm used to pick source address. Balancer BalancerType // FallbackDelay specifies the length of time to wait before // spawning a RFC 6555 Fast Fallback connection. That is, this // is the amount of time to wait for IPv6 to succeed before // assuming that IPv6 is misconfigured and falling back to // IPv4. // // If zero, a default delay of 300ms is used. // A negative value disables Fast Fallback support. FallbackDelay time.Duration // InterfaceSource is custom `Interface`` source. // If not specified, default implementation is used (`net.Interfaces()``). InterfaceSource func() ([]Interface, error) // DialContext is custom DialContext function. // If not specified, default implemenattion is used (`d.DialContext(ctx, network, address)`). DialContext func(d *net.Dialer, ctx context.Context, network, address string) (net.Conn, error) } // NewDialer ... func NewDialer(c Config) (Multidialer, error) { var ifaces []Interface var err error if c.InterfaceSource != nil { ifaces, err = c.InterfaceSource() } else { ifaces, err = systemInterfaces() } if err != nil { return nil, err } sort.Slice(ifaces, func(i, j int) bool { return ifaces[i].Name() < ifaces[j].Name() }) var sources []iface for i := range ifaces { info, err := processIface(ifaces[i]) if err != nil { return nil, err } sources = append(sources, info) } var d dialer for _, subnet := range c.Subnets { s, err := processSubnet(subnet, sources) if err != nil { return nil, err } d.subnets = append(d.subnets, s) } switch c.Balancer { case BalancerTypeNoop: d.balancer = &firstEnabled{d: &d} case BalancerTypeRoundRobin: d.balancer = &roundRobin{d: &d} default: return nil, fmt.Errorf("invalid balancer type: %s", c.Balancer) } d.restrict = c.Restrict d.resolver.Dial = d.dialContextIP d.fallbackDelay = c.FallbackDelay if d.fallbackDelay == 0 { d.fallbackDelay = defaultFallbackDelay } d.dialer = c.Dialer if c.DialContext != nil { d.customDialContext = c.DialContext } return &d, nil } type iface struct { name string addrs []netip.Prefix } func processIface(info Interface) (iface, error) { ips, err := info.Addrs() if err != nil { return iface{}, err } var addrs []netip.Prefix for i := range ips { p, err := netip.ParsePrefix(ips[i].String()) if err != nil { return iface{}, err } addrs = append(addrs, p) } return iface{name: info.Name(), addrs: addrs}, nil } func processSubnet(subnet string, sources []iface) (Subnet, error) { s, err := netip.ParsePrefix(subnet) if err != nil { return Subnet{}, err } var ifs []Source for _, source := range sources { for i := range source.addrs { src := source.addrs[i].Addr() if s.Contains(src) { ifs = append(ifs, Source{ Name: source.name, LocalAddr: &net.TCPAddr{IP: net.IP(src.AsSlice())}, }) } } } sort.Slice(ifs, func(i, j int) bool { if ifs[i].Name != ifs[j].Name { return ifs[i].Name < ifs[j].Name } return bytes.Compare(ifs[i].LocalAddr.IP, ifs[j].LocalAddr.IP) == -1 }) return Subnet{ Mask: s, Interfaces: ifs, }, nil } // DialContext implements the Dialer interface. // Hostnames for address are currently not supported. func (d *dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { addr, err := netip.ParseAddrPort(address) if err != nil { //try resolve as hostname return d.dialContextHostname(ctx, network, address) } return d.dialAddr(ctx, network, address, addr) } func (d *dialer) dialContextIP(ctx context.Context, network, address string) (net.Conn, error) { addr, err := netip.ParseAddrPort(address) if err != nil { return nil, err } return d.dialAddr(ctx, network, address, addr) } func (d *dialer) dialContextHostname(ctx context.Context, network, address string) (net.Conn, error) { // https://github.com/golang/go/blob/release-branch.go1.21/src/net/dial.go#L488 addrPorts, err := d.resolveHostIPs(ctx, address, network) if err != nil { return nil, err } var primaries, fallbacks []netip.AddrPort if d.fallbackDelay >= 0 && network == "tcp" { primaries, fallbacks = splitByType(addrPorts) } else { primaries = addrPorts } return d.dialParallel(ctx, network, primaries, fallbacks) } func (d *dialer) dialSerial(ctx context.Context, network string, addrs []netip.AddrPort) (net.Conn, error) { var firstErr error for _, addr := range addrs { select { case <-ctx.Done(): return nil, ctx.Err() default: } conn, err := d.dialAddr(ctx, network, addr.String(), addr) if err == nil { return conn, nil } if firstErr == nil { firstErr = err } } return nil, fmt.Errorf("failed to connect any resolved address: %w", firstErr) } func (d *dialer) dialParallel(ctx context.Context, network string, primaries, fallbacks []netip.AddrPort) (net.Conn, error) { if len(fallbacks) == 0 { return d.dialSerial(ctx, network, primaries) } returned := make(chan struct{}) defer close(returned) type dialResult struct { net.Conn error primary bool done bool } results := make(chan dialResult) dialerFunc := func(ctx context.Context, primary bool) { addrs := primaries if !primary { addrs = fallbacks } c, err := d.dialSerial(ctx, network, addrs) select { case results <- dialResult{Conn: c, error: err, primary: primary, done: true}: case <-returned: if c != nil { c.Close() } } } var primary, fallback dialResult primaryCtx, primaryCancel := context.WithCancel(ctx) defer primaryCancel() go dialerFunc(primaryCtx, true) fallbackTimer := time.NewTimer(d.fallbackDelay) defer fallbackTimer.Stop() for { select { case <-fallbackTimer.C: fallbackCtx, fallbackCancel := context.WithCancel(ctx) defer fallbackCancel() go dialerFunc(fallbackCtx, false) case res := <-results: if res.error == nil { return res.Conn, nil } if res.primary { primary = res } else { fallback = res } if primary.done && fallback.done { return nil, primary.error } if res.primary && fallbackTimer.Stop() { // If we were able to stop the timer, that means it // was running (hadn't yet started the fallback), but // we just got an error on the primary path, so start // the fallback immediately (in 0 nanoseconds). fallbackTimer.Reset(0) } } } } func (d *dialer) resolveHostIPs(ctx context.Context, address string, network string) ([]netip.AddrPort, error) { host, port, err := net.SplitHostPort(address) if err != nil { return nil, fmt.Errorf("invalid address format: %w", err) } portnum, err := d.resolver.LookupPort(ctx, network, port) if err != nil { return nil, fmt.Errorf("invalid port format: %w", err) } ips, err := d.resolver.LookupHost(ctx, host) if err != nil { return nil, fmt.Errorf("failed to resolve host address: %w", err) } if len(ips) == 0 { return nil, fmt.Errorf("failed to resolve address for [%s]%s", network, address) } var ipAddrs []netip.Addr for _, ip := range ips { ipAddr, err := netip.ParseAddr(ip) if err != nil { return nil, fmt.Errorf("failed to parse ip address '%s': %w", ip, err) } ipAddrs = append(ipAddrs, ipAddr) } var addrPorts []netip.AddrPort for _, ipAddr := range ipAddrs { addrPorts = append(addrPorts, netip.AddrPortFrom(ipAddr, uint16(portnum))) } return addrPorts, nil } func (d *dialer) dialAddr(ctx context.Context, network, address string, addr netip.AddrPort) (net.Conn, error) { d.mtx.RLock() defer d.mtx.RUnlock() for i := range d.subnets { if d.subnets[i].Mask.Contains(addr.Addr()) { return d.balancer.DialContext(ctx, &d.subnets[i], network, address) } } if d.restrict { return nil, fmt.Errorf("no suitable interface for: [%s]%s", network, address) } return d.dialContext(&d.dialer, ctx, network, address) } func (d *dialer) dialContext(nd *net.Dialer, ctx context.Context, network, address string) (net.Conn, error) { if h := d.customDialContext; h != nil { return h(nd, ctx, network, address) } return nd.DialContext(ctx, network, address) } // UpdateInterface implements the Multidialer interface. // Updating address on a specific interface is currently not supported. func (d *dialer) UpdateInterface(iface string, addr netip.Addr, up bool) error { d.mtx.Lock() defer d.mtx.Unlock() for i := range d.subnets { for j := range d.subnets[i].Interfaces { matchIface := d.subnets[i].Interfaces[j].Name == iface if matchIface { d.subnets[i].Interfaces[j].Down = !up continue } a, _ := netip.AddrFromSlice(d.subnets[i].Interfaces[j].LocalAddr.IP) matchAddr := a.IsUnspecified() || addr == a if matchAddr { d.subnets[i].Interfaces[j].Down = !up } } } return nil } // splitByType divides an address list into two categories: // the first address, and any with same type, are returned as // primaries, while addresses with the opposite type are returned // as fallbacks. func splitByType(addrs []netip.AddrPort) (primaries []netip.AddrPort, fallbacks []netip.AddrPort) { var primaryLabel bool for i, addr := range addrs { label := addr.Addr().Is4() if i == 0 || label == primaryLabel { primaryLabel = label primaries = append(primaries, addr) } else { fallbacks = append(fallbacks, addr) } } return }