commit b92dae991b2864853730d408c8ed56007c3bcf9a Author: Evgenii Stratonikov Date: Wed Aug 16 20:10:49 2023 +0300 Initial commit Signed-off-by: Evgenii Stratonikov diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..4d77641 --- /dev/null +++ b/Makefile @@ -0,0 +1,3 @@ +test: + # TODO figure out needed capabilities + sudo go test -count=1 -v ./... diff --git a/README.md b/README.md new file mode 100644 index 0000000..b08a363 --- /dev/null +++ b/README.md @@ -0,0 +1,41 @@ +# Source-based routing in Golang + + +Consider this routing table: +``` +10.11.70.0/23 dev data0 proto kernel scope link src 10.11.70.42 +10.11.70.0/23 dev data1 proto kernel scope link src 10.11.71.42 +192.168.123.0/24 dev internal0 proto kernel scope link src 192.168.123.42 +192.168.123.0/24 dev internal1 proto kernel scope link src 192.168.123.142 +``` + +Simple `net.Dial` to either `10.11.70.42` or `10.11.71.42` will match the first subnet and be routed via data0. +This problems is usually solved by bonds. +But sometimes you need to invent a bicycle. + +## Usage + +```golang +import "git.frostfs.info/TrueCloudLab/multinet" + +d, err := multinet.NewDialer(Config{ + Subnets: []string{"10.11.70.0/23", "192.168.123.0/24"}, + Balancer: multinet.BalancerTypeRoundRobin, +}) +if err != nil { + // handle error +} + +conn, err := d.DialContext(ctx, "tcp", "10.11.70.42") +if err != nil { + // handle error +} +// do stuff +``` + +### Updating interface state + +`Multidialer` exposes `UpdateInterface()` method for updating state of a single link. +`NetlinkWatcher` can wrap `Multidialer` type and perform all updates automatically. + +TODO: describe needed capabilities here. \ No newline at end of file diff --git a/balancer.go b/balancer.go new file mode 100644 index 0000000..67245ea --- /dev/null +++ b/balancer.go @@ -0,0 +1,61 @@ +package multinet + +import ( + "context" + "errors" + "net" + "sync/atomic" +) + +// BalancerType reperents the algorithm which is used to pick source address. +type BalancerType string + +const ( + // BalancerTypeNoop picks first address for which link is up. + BalancerTypeNoop BalancerType = "" + // BalancerTypeNoop implements simple round-robin between up links. + // It is not fair in case some links are down. + BalancerTypeRoundRobin BalancerType = "roundrobin" +) + +type balancer interface { + DialContext(ctx context.Context, s *Subnet, network, address string) (net.Conn, error) +} + +type roundRobin struct { + d *dialer + i atomic.Uint32 +} + +func (r *roundRobin) DialContext(ctx context.Context, s *Subnet, network, address string) (net.Conn, error) { + next := int(r.i.Add(1)) + for i := range s.Interfaces { + ii := s.Interfaces[(i+next)%len(s.Interfaces)] + if ii.Down { + continue + } + + dd := r.d.dialer + dd.LocalAddr = ii.LocalAddr + return r.d.dialContext(&dd, ctx, "tcp", address) + } + return nil, errors.New("(*roundRobin).DialContext: no suitale node found") +} + +type firstEnabled struct { + d *dialer +} + +func (r *firstEnabled) DialContext(ctx context.Context, s *Subnet, network, address string) (net.Conn, error) { + for i := range s.Interfaces { + ii := s.Interfaces[i%len(s.Interfaces)] + if ii.Down { + continue + } + + dd := r.d.dialer + dd.LocalAddr = ii.LocalAddr + return r.d.dialContext(&dd, ctx, "tcp", address) + } + return nil, errors.New("(*firstEnabled).DialContext: no suitale node found") +} diff --git a/dialer.go b/dialer.go new file mode 100644 index 0000000..a0400d4 --- /dev/null +++ b/dialer.go @@ -0,0 +1,219 @@ +package multinet + +import ( + "bytes" + "context" + "fmt" + "net" + "net/netip" + "sort" + "sync" +) + +// Dialer contains the single most important method from the net.Dialer. +type Dialer interface { + DialContext(ctx context.Context, network, address string) (net.Conn, error) +} + +// Multidialer is like Dialer, but supports link state updates. +type Multidialer interface { + Dialer + UpdateInterface(name string, addr netip.Addr, status bool) error +} + +var ( + _ Dialer = (*net.Dialer)(nil) + _ Dialer = (*dialer)(nil) +) + +type dialer struct { + // Protects subnets field (recursively). + mtx sync.RWMutex + subnets []Subnet + // Default options for the net.Dialer. + dialer net.Dialer + // If true, allow to dial only configured subnets. + restrict bool + // Algorithm to which picks source address for each ip. + balancer balancer + // Hook used in tests, overrides `net.Dialer.DialContext()` + testHookDialContext func(d *net.Dialer, ctx context.Context, network, address string) (net.Conn, error) +} + +// Subnet represents a single subnet, possibly routable from multiple interfaces. +type Subnet struct { + Mask netip.Prefix + Interfaces []Source +} + +// Source represents a single source IP belonging to a particular subnet. +type Source struct { + Name string + LocalAddr *net.TCPAddr + Down bool +} + +// Config contains Multidialer configuration. +type Config struct { + // Routable subnets to prioritize in CIDR format. + Subnets []string + // If true, the only configurd subnets available through this dialer. + // Otherwise, a failback to the net.DefaultDialer. + Restrict bool + // Dialer containes default options for the net.Dialer to use. + // LocalAddr is overriden. + Dialer net.Dialer + // Balancer specifies algorithm used to pick source address. + Balancer BalancerType +} + +// NewDialer ... +func NewDialer(c Config) (Multidialer, error) { + ifaces, err := net.Interfaces() + if err != nil { + return nil, err + } + sort.Slice(ifaces, func(i, j int) bool { + return ifaces[i].Name < ifaces[j].Name + }) + + var sources []iface + for i := range ifaces { + info, err := processIface(ifaces[i]) + if err != nil { + return nil, err + } + sources = append(sources, info) + } + + var d dialer + for _, subnet := range c.Subnets { + s, err := processSubnet(subnet, sources) + if err != nil { + return nil, err + } + d.subnets = append(d.subnets, s) + } + + switch c.Balancer { + case BalancerTypeNoop: + d.balancer = &firstEnabled{d: &d} + case BalancerTypeRoundRobin: + d.balancer = &roundRobin{d: &d} + default: + return nil, fmt.Errorf("invalid balancer type: %s", c.Balancer) + } + + d.restrict = c.Restrict + + return &d, nil +} + +type iface struct { + info net.Interface + addrs []netip.Prefix +} + +func processIface(info net.Interface) (iface, error) { + ips, err := info.Addrs() + if err != nil { + return iface{}, err + } + + var addrs []netip.Prefix + for i := range ips { + p, err := netip.ParsePrefix(ips[i].String()) + if err != nil { + return iface{}, err + } + + addrs = append(addrs, p) + } + return iface{info: info, addrs: addrs}, nil +} + +func processSubnet(subnet string, sources []iface) (Subnet, error) { + s, err := netip.ParsePrefix(subnet) + if err != nil { + return Subnet{}, err + } + + var ifs []Source + for _, source := range sources { + for i := range source.addrs { + src := source.addrs[i].Addr() + if s.Contains(src) { + ifs = append(ifs, Source{ + Name: source.info.Name, + LocalAddr: &net.TCPAddr{IP: net.IP(src.AsSlice())}, + }) + } + } + } + + sort.Slice(ifs, func(i, j int) bool { + if ifs[i].Name != ifs[j].Name { + return ifs[i].Name < ifs[j].Name + } + return bytes.Compare(ifs[i].LocalAddr.IP, ifs[j].LocalAddr.IP) == -1 + }) + + return Subnet{ + Mask: s, + Interfaces: ifs, + }, nil +} + +// DialContext implements the Dialer interface. +// Hostnames for address are currently not supported. +func (d *dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + addr, err := netip.ParseAddrPort(address) + if err != nil { + return nil, err + } + + d.mtx.RLock() + defer d.mtx.RUnlock() + + for i := range d.subnets { + if d.subnets[i].Mask.Contains(addr.Addr()) { + return d.balancer.DialContext(ctx, &d.subnets[i], network, address) + } + } + + if d.restrict { + return nil, fmt.Errorf("no suitable interface for: [%s]%s", network, address) + } + return d.dialer.DialContext(ctx, network, address) +} + +func (d *dialer) dialContext(nd *net.Dialer, ctx context.Context, network, address string) (net.Conn, error) { + if h := d.testHookDialContext; h != nil { + return h(nd, ctx, "tcp", address) + } + return nd.DialContext(ctx, "tcp", address) +} + +// UpdateInterface implements the Multidialer interface. +// Updating address on a specific interface is currently not supported. +func (d *dialer) UpdateInterface(iface string, addr netip.Addr, up bool) error { + d.mtx.Lock() + defer d.mtx.Unlock() + + for i := range d.subnets { + for j := range d.subnets[i].Interfaces { + matchIface := d.subnets[i].Interfaces[j].Name == iface + if matchIface { + d.subnets[i].Interfaces[j].Down = !up + continue + } + + a, _ := netip.AddrFromSlice(d.subnets[i].Interfaces[j].LocalAddr.IP) + matchAddr := a.IsUnspecified() || addr == a + if matchAddr { + d.subnets[i].Interfaces[j].Down = !up + } + } + } + return nil +} diff --git a/dialer_test.go b/dialer_test.go new file mode 100644 index 0000000..69e4581 --- /dev/null +++ b/dialer_test.go @@ -0,0 +1,164 @@ +package multinet + +import ( + "net" + "net/netip" + "runtime" + "testing" + + "github.com/stretchr/testify/require" + "github.com/vishvananda/netlink" + "github.com/vishvananda/netns" +) + +func TestDialer(t *testing.T) { + runInNewNamespace(t, "2 interfaces with multiple routes in different subnets", func(t *testing.T, ns netns.NsHandle) { + setup(t, map[string][]string{ + "testdev1": {"1.2.30.10/23", "4.4.4.4/8"}, + "testdev2": {"1.2.30.11/23", "4.4.4.5/8"}, + }) + + // Do not use `t.Run` because everything should be executed in a single OS thread. + + { // Restrict to a single subnet. + d, err := NewDialer(Config{ + Subnets: []string{"1.2.30.0/23"}, + }) + require.NoError(t, err) + require.Equal(t, []Subnet{ + { + Mask: netip.MustParsePrefix("1.2.30.0/23"), + Interfaces: []Source{ + {Name: "testdev1", LocalAddr: &net.TCPAddr{IP: net.IP{1, 2, 30, 10}}}, + {Name: "testdev2", LocalAddr: &net.TCPAddr{IP: net.IP{1, 2, 30, 11}}}, + }, + }, + }, d.(*dialer).subnets) + } + + { // Restrict to two subnets. + d, err := NewDialer(Config{ + Subnets: []string{"1.2.30.0/23", "4.0.0.0/8"}, + }) + require.NoError(t, err) + require.Equal(t, []Subnet{ + { + Mask: netip.MustParsePrefix("1.2.30.0/23"), + Interfaces: []Source{ + {Name: "testdev1", LocalAddr: &net.TCPAddr{IP: net.IP{1, 2, 30, 10}}}, + {Name: "testdev2", LocalAddr: &net.TCPAddr{IP: net.IP{1, 2, 30, 11}}}, + }, + }, + { + Mask: netip.MustParsePrefix("4.0.0.0/8"), + Interfaces: []Source{ + {Name: "testdev1", LocalAddr: &net.TCPAddr{IP: net.IP{4, 4, 4, 4}}}, + {Name: "testdev2", LocalAddr: &net.TCPAddr{IP: net.IP{4, 4, 4, 5}}}, + }, + }, + }, d.(*dialer).subnets) + } + }) + runInNewNamespace(t, "4 interfaces, 2 for data, 2 internal", func(t *testing.T, ns netns.NsHandle) { + setup(t, map[string][]string{ + "internal1": {"192.168.0.1/16"}, + "internal2": {"192.168.0.2/16"}, + "data1": {"10.11.12.101/24"}, + "data2": {"10.11.12.102/24"}, + }) + + d, err := NewDialer(Config{ + Subnets: []string{"10.11.12.0/24", "192.168.0.0/16"}, + }) + require.NoError(t, err) + require.Equal(t, []Subnet{ + { + Mask: netip.MustParsePrefix("10.11.12.0/24"), + Interfaces: []Source{ + {Name: "data1", LocalAddr: &net.TCPAddr{IP: net.IP{10, 11, 12, 101}}}, + {Name: "data2", LocalAddr: &net.TCPAddr{IP: net.IP{10, 11, 12, 102}}}, + }, + }, + { + Mask: netip.MustParsePrefix("192.168.0.0/16"), + Interfaces: []Source{ + {Name: "internal1", LocalAddr: &net.TCPAddr{IP: net.IP{192, 168, 0, 1}}}, + {Name: "internal2", LocalAddr: &net.TCPAddr{IP: net.IP{192, 168, 0, 2}}}, + }, + }, + }, d.(*dialer).subnets) + }) + runInNewNamespace(t, "with ipv6", func(t *testing.T, ns netns.NsHandle) { + addr1 := "2001:db8:85a3:8d3:1319:8a2e:370:7348/64" + addr2 := "2001:db8:85a3:8d3:1319:8a2e:370:8192/64" + setup(t, map[string][]string{ + "testdev1": {addr1}, + "testdev2": {addr2}, + }) + + // Do not use `t.Run` because everything should be executed in a single OS thread. + + { // Restrict to a single subnet. + d, err := NewDialer(Config{ + Subnets: []string{"2001:db8:85a3:8d3::/64"}, + }) + require.NoError(t, err) + require.Equal(t, []Subnet{ + { + Mask: netip.MustParsePrefix("2001:db8:85a3:8d3::/64"), + Interfaces: []Source{ + {Name: "testdev1", LocalAddr: mustParseIPv6(t, addr1)}, + {Name: "testdev2", LocalAddr: mustParseIPv6(t, addr2)}, + }, + }, + }, d.(*dialer).subnets) + } + }) +} + +func mustParseIPv6(t *testing.T, s string) *net.TCPAddr { + ip, _, err := net.ParseCIDR(s) + require.NoError(t, err) + return &net.TCPAddr{IP: ip} +} + +func setup(t *testing.T, config map[string][]string) { + for name, ips := range config { + link := createLink(t, name) + for i := range ips { + ip, err := netlink.ParseIPNet(ips[i]) + require.NoError(t, err) + require.NoError(t, netlink.AddrAdd(link, &netlink.Addr{IPNet: ip})) + } + } +} + +func createLink(t *testing.T, name string) netlink.Link { + require.NoError(t, netlink.LinkAdd(&netlink.Dummy{LinkAttrs: netlink.LinkAttrs{Name: name}})) + + link, err := netlink.LinkByName(name) + require.NoError(t, err) + require.NoError(t, netlink.LinkSetUp(link)) + return link +} + +func runInNewNamespace(t *testing.T, name string, f func(t *testing.T, ns netns.NsHandle)) { + t.Run(name, func(t *testing.T) { + // To avoid messing with host network settings, + // we create a new names space and execute tests in it. + // Switching thread can move us to a different namespace, thus this line. + runtime.LockOSThread() + defer runtime.UnlockOSThread() + + origns, err := netns.Get() + require.NoError(t, err) + defer origns.Close() + defer netns.Set(origns) + + newns, err := netns.New() + require.NoError(t, err) + defer newns.Close() + + f(t, newns) + }) +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..6ee639f --- /dev/null +++ b/go.mod @@ -0,0 +1,16 @@ +module git.frostfs.info/TrueCloudLab/multinet + +go 1.21.0 + +require ( + github.com/stretchr/testify v1.8.4 + github.com/vishvananda/netlink v1.1.0 + github.com/vishvananda/netns v0.0.4 + golang.org/x/sys v0.2.0 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..83fb4e7 --- /dev/null +++ b/go.sum @@ -0,0 +1,18 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/vishvananda/netlink v1.1.0 h1:1iyaYNBLmP6L0220aDnYQpo1QEV4t4hJ+xEEhhJH8j0= +github.com/vishvananda/netlink v1.1.0/go.mod h1:cTgwzPIzzgDAYoQrMm0EdrjRUBkTqKYppBueQtXaqoE= +github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df/go.mod h1:JP3t17pCcGlemwknint6hfoeCVQrEMVwxRLRjXpq+BU= +github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= +github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= +golang.org/x/sys v0.0.0-20190606203320-7fc4e5ec1444/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.2.0 h1:ljd4t30dBnAvMZaQCevtY0xLLD0A+bRZXbgLMLU1F/A= +golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/health.go b/health.go new file mode 100644 index 0000000..da60f75 --- /dev/null +++ b/health.go @@ -0,0 +1,73 @@ +package multinet + +import ( + "net/netip" + "sync" + + "github.com/vishvananda/netlink" + "github.com/vishvananda/netns" + "golang.org/x/sys/unix" +) + +type NetlinkWatcher struct { + d Multidialer + linkUpdates chan netlink.LinkUpdate + addrUpdates chan netlink.AddrUpdate + done chan struct{} + wg sync.WaitGroup +} + +func NewNetlinkWatcher(d Multidialer) *NetlinkWatcher { + return &NetlinkWatcher{ + d: d, + addrUpdates: make(chan netlink.AddrUpdate, 1), + linkUpdates: make(chan netlink.LinkUpdate, 1), + done: make(chan struct{}), + } +} + +func (w *NetlinkWatcher) Start() error { + ns, err := netns.Get() + if err != nil { + return err + } + + if err := netlink.LinkSubscribe(w.linkUpdates, w.done); err != nil { + return err + } + if err := netlink.AddrSubscribe(w.addrUpdates, w.done); err != nil { + close(w.done) + return err + } + + w.wg.Add(1) + go w.watch(ns) + return nil +} + +func (w *NetlinkWatcher) watch(ns netns.NsHandle) { + defer w.wg.Done() + + for { + select { + case <-w.done: + return + case update := <-w.addrUpdates: + // Wont work if an multiple interfaces share IP address. + // Should not happen in practice. + ip, ok := netip.AddrFromSlice(update.LinkAddress.IP) + if !ok { + continue + } + w.d.UpdateInterface("", ip, update.NewAddr) + case update := <-w.linkUpdates: + up := update.Flags&unix.IFF_UP != 0 + w.d.UpdateInterface(update.Link.Attrs().Name, netip.Addr{}, up) + } + } +} + +func (w *NetlinkWatcher) Stop() { + close(w.done) + w.wg.Wait() +} diff --git a/health_test.go b/health_test.go new file mode 100644 index 0000000..333df2e --- /dev/null +++ b/health_test.go @@ -0,0 +1,158 @@ +package multinet + +import ( + "context" + "net" + "testing" + "time" + + "github.com/stretchr/testify/require" + "github.com/vishvananda/netlink" + "github.com/vishvananda/netns" +) + +func Test_NetlinkWatcher(t *testing.T) { + runInNewNamespace(t, "noop balancer, disable interface", func(t *testing.T, ns netns.NsHandle) { + setup(t, map[string][]string{ + "testdev1": {"1.2.30.11/23"}, + "testdev2": {"1.2.30.12/23"}, + }) + + addr1 := &net.TCPAddr{IP: net.IP{1, 2, 30, 11}} + addr2 := &net.TCPAddr{IP: net.IP{1, 2, 30, 12}} + + d, err := NewDialer(Config{ + Subnets: []string{"1.2.30.0/23"}, + }) + require.NoError(t, err) + + w := NewNetlinkWatcher(d) + require.NoError(t, w.Start()) + t.Cleanup(w.Stop) + + result := make(chan net.Addr, 1) + d.(*dialer).testHookDialContext = func(d *net.Dialer, _ context.Context, _, _ string) (net.Conn, error) { + result <- d.LocalAddr + return nil, nil + } + + checkDialAddr(t, d, result, addr1) + checkDialAddr(t, d, result, addr1) + + link, err := netlink.LinkByName("testdev1") + require.NoError(t, err) + require.NoError(t, netlink.LinkSetDown(link)) + time.Sleep(time.Second) + + checkDialAddr(t, d, result, addr2) + checkDialAddr(t, d, result, addr2) + + require.NoError(t, netlink.LinkSetUp(link)) + time.Sleep(time.Second) + + checkDialAddr(t, d, result, addr1) + }) + + runInNewNamespace(t, "noop balancer, remove address", func(t *testing.T, ns netns.NsHandle) { + setup(t, map[string][]string{ + "testdev1": {"1.2.30.11/23"}, + "testdev2": {"1.2.30.12/23"}, + }) + + addr1 := &net.TCPAddr{IP: net.IP{1, 2, 30, 11}} + addr2 := &net.TCPAddr{IP: net.IP{1, 2, 30, 12}} + + d, err := NewDialer(Config{ + Subnets: []string{"1.2.30.0/23"}, + }) + require.NoError(t, err) + + w := NewNetlinkWatcher(d) + require.NoError(t, w.Start()) + t.Cleanup(w.Stop) + + result := make(chan net.Addr, 1) + d.(*dialer).testHookDialContext = func(d *net.Dialer, _ context.Context, _, _ string) (net.Conn, error) { + result <- d.LocalAddr + return nil, nil + } + + checkDialAddr(t, d, result, addr1) + checkDialAddr(t, d, result, addr1) + + link, err := netlink.LinkByName("testdev1") + require.NoError(t, err) + + ip, err := netlink.ParseIPNet("1.2.30.11/23") + require.NoError(t, err) + require.NoError(t, netlink.AddrDel(link, &netlink.Addr{IPNet: ip})) + time.Sleep(time.Second) + + checkDialAddr(t, d, result, addr2) + checkDialAddr(t, d, result, addr2) + + require.NoError(t, netlink.AddrAdd(link, &netlink.Addr{IPNet: ip})) + time.Sleep(time.Second) + + checkDialAddr(t, d, result, addr1) + }) + + runInNewNamespace(t, "round-robin balancer, disable interface", func(t *testing.T, ns netns.NsHandle) { + setup(t, map[string][]string{ + "testdev1": {"1.2.30.11/23"}, + "testdev2": {"1.2.30.12/23"}, + }) + + addr1 := &net.TCPAddr{IP: net.IP{1, 2, 30, 11}} + addr2 := &net.TCPAddr{IP: net.IP{1, 2, 30, 12}} + + d, err := NewDialer(Config{ + Subnets: []string{"1.2.30.0/23"}, + Balancer: BalancerTypeRoundRobin, + }) + require.NoError(t, err) + + w := NewNetlinkWatcher(d) + require.NoError(t, w.Start()) + t.Cleanup(w.Stop) + + result := make(chan net.Addr, 1) + d.(*dialer).testHookDialContext = func(d *net.Dialer, _ context.Context, _, _ string) (net.Conn, error) { + result <- d.LocalAddr + return nil, nil + } + + checkDialAddr(t, d, result, addr2) + checkDialAddr(t, d, result, addr1) + checkDialAddr(t, d, result, addr2) + + link, err := netlink.LinkByName("testdev1") + require.NoError(t, err) + require.NoError(t, netlink.LinkSetDown(link)) + time.Sleep(time.Second) + + checkDialAddr(t, d, result, addr2) + checkDialAddr(t, d, result, addr2) + + require.NoError(t, netlink.LinkSetUp(link)) + time.Sleep(time.Second) + + checkDialAddr(t, d, result, addr1) + checkDialAddr(t, d, result, addr2) + }) +} + +func checkDialAddr(t *testing.T, d Multidialer, ch chan net.Addr, expected net.Addr) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + _, err := d.DialContext(ctx, "tcp", "1.2.30.42:12345") + require.NoError(t, err) + + select { + case addr := <-ch: + require.Equal(t, expected, addr) + default: + require.Fail(t, "DialContext() was not called") + } +}