package multinet

import (
	"bytes"
	"context"
	"fmt"
	"net"
	"net/netip"
	"sort"
	"sync"
)

// Dialer contains the single most important method from the net.Dialer.
type Dialer interface {
	DialContext(ctx context.Context, network, address string) (net.Conn, error)
}

// Multidialer is like Dialer, but supports link state updates.
type Multidialer interface {
	Dialer
	UpdateInterface(name string, addr netip.Addr, status bool) error
}

// Compile-time interface conformance checks.
var (
	_ Dialer      = (*net.Dialer)(nil)
	_ Dialer      = (*dialer)(nil)
	_ Multidialer = (*dialer)(nil)
)

type dialer struct {
	// mtx protects the subnets field, including the Source entries nested
	// inside each Subnet.
	mtx     sync.RWMutex
	subnets []Subnet

	// dialer holds the default options for the net.Dialer.
	dialer net.Dialer

	// restrict, if true, allows dialing only the configured subnets.
	restrict bool

	// balancer is the algorithm that picks a source address for each IP.
	balancer balancer

	// testHookDialContext is used in tests; when set it overrides
	// net.Dialer.DialContext.
	testHookDialContext func(d *net.Dialer, ctx context.Context, network, address string) (net.Conn, error)
}

// Subnet represents a single subnet, possibly routable from multiple
// interfaces. Mask is the subnet prefix and Interfaces lists the local
// sources whose address lies inside it.
type Subnet struct {
	Mask       netip.Prefix
	Interfaces []Source
}

// Source represents a single source IP belonging to a particular subnet.
// Name is the interface name, LocalAddr is the interface address inside the
// subnet, and Down is toggled by UpdateInterface.
type Source struct {
	Name      string
	LocalAddr *net.TCPAddr
	Down      bool
}

// Config contains Multidialer configuration.
type Config struct {
	// Subnets lists the routable subnets to prioritize, in CIDR format.
	Subnets []string

	// Restrict, if true, makes only the configured subnets available through
	// this dialer. Otherwise, other destinations fall back to a regular
	// net.Dialer.
	Restrict bool

	// Dialer contains the default options for the net.Dialer to use.
	// LocalAddr is overridden.
	Dialer net.Dialer

	// Balancer specifies the algorithm used to pick a source address.
	Balancer BalancerType
}

// NewDialer creates a Multidialer that routes connections for the configured
// subnets through the addresses of the local interfaces found inside them.
//
// The local interfaces and their addresses are read once, at construction
// time. Afterwards the sources can only be marked up or down through
// UpdateInterface. NewDialer returns an error if the interfaces or their
// addresses cannot be read, if a subnet is not valid CIDR, or if c.Balancer
// is not a known balancer type.
func NewDialer(c Config) (Multidialer, error) { ifaces, err := net.Interfaces() if err != nil { return nil, err } sort.Slice(ifaces, func(i, j int) bool { return ifaces[i].Name < ifaces[j].Name }) var sources []iface for i := range ifaces { info, err := processIface(ifaces[i]) if err != nil { return nil, err } sources = append(sources, info) } var d dialer for _, subnet := range c.Subnets { s, err := processSubnet(subnet, sources) if err != nil { return nil, err } d.subnets = append(d.subnets, s) } switch c.Balancer { case BalancerTypeNoop: d.balancer = &firstEnabled{d: &d} case BalancerTypeRoundRobin: d.balancer = &roundRobin{d: &d} default: return nil, fmt.Errorf("invalid balancer type: %s", c.Balancer) } d.restrict = c.Restrict return &d, nil } type iface struct { info net.Interface addrs []netip.Prefix } func processIface(info net.Interface) (iface, error) { ips, err := info.Addrs() if err != nil { return iface{}, err } var addrs []netip.Prefix for i := range ips { p, err := netip.ParsePrefix(ips[i].String()) if err != nil { return iface{}, err } addrs = append(addrs, p) } return iface{info: info, addrs: addrs}, nil } func processSubnet(subnet string, sources []iface) (Subnet, error) { s, err := netip.ParsePrefix(subnet) if err != nil { return Subnet{}, err } var ifs []Source for _, source := range sources { for i := range source.addrs { src := source.addrs[i].Addr() if s.Contains(src) { ifs = append(ifs, Source{ Name: source.info.Name, LocalAddr: &net.TCPAddr{IP: net.IP(src.AsSlice())}, }) } } } sort.Slice(ifs, func(i, j int) bool { if ifs[i].Name != ifs[j].Name { return ifs[i].Name < ifs[j].Name } return bytes.Compare(ifs[i].LocalAddr.IP, ifs[j].LocalAddr.IP) == -1 }) return Subnet{ Mask: s, Interfaces: ifs, }, nil } // DialContext implements the Dialer interface. // Hostnames for address are currently not supported. 
func (d *dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { addr, err := netip.ParseAddrPort(address) if err != nil { return nil, err } d.mtx.RLock() defer d.mtx.RUnlock() for i := range d.subnets { if d.subnets[i].Mask.Contains(addr.Addr()) { return d.balancer.DialContext(ctx, &d.subnets[i], network, address) } } if d.restrict { return nil, fmt.Errorf("no suitable interface for: [%s]%s", network, address) } return d.dialer.DialContext(ctx, network, address) } func (d *dialer) dialContext(nd *net.Dialer, ctx context.Context, network, address string) (net.Conn, error) { if h := d.testHookDialContext; h != nil { return h(nd, ctx, "tcp", address) } return nd.DialContext(ctx, "tcp", address) } // UpdateInterface implements the Multidialer interface. // Updating address on a specific interface is currently not supported. func (d *dialer) UpdateInterface(iface string, addr netip.Addr, up bool) error { d.mtx.Lock() defer d.mtx.Unlock() for i := range d.subnets { for j := range d.subnets[i].Interfaces { matchIface := d.subnets[i].Interfaces[j].Name == iface if matchIface { d.subnets[i].Interfaces[j].Down = !up continue } a, _ := netip.AddrFromSlice(d.subnets[i].Interfaces[j].LocalAddr.IP) matchAddr := a.IsUnspecified() || addr == a if matchAddr { d.subnets[i].Interfaces[j].Down = !up } } } return nil }