//go:build integration package multinet import ( "context" "net" "testing" "time" "github.com/stretchr/testify/require" "github.com/vishvananda/netlink" "github.com/vishvananda/netns" ) func Test_NetlinkWatcher(t *testing.T) { runInNewNamespace(t, "noop balancer, disable interface", func(t *testing.T, ns netns.NsHandle) { setup(t, map[string][]string{ "testdev1": {"1.2.30.11/23"}, "testdev2": {"1.2.30.12/23"}, }) addr1 := &net.TCPAddr{IP: net.IP{1, 2, 30, 11}} addr2 := &net.TCPAddr{IP: net.IP{1, 2, 30, 12}} d, err := NewDialer(Config{ Subnets: []string{"1.2.30.0/23"}, }) require.NoError(t, err) w := NewNetlinkWatcher(d) require.NoError(t, w.Start()) t.Cleanup(w.Stop) result := make(chan net.Addr, 1) d.(*dialer).testHookDialContext = func(d *net.Dialer, _ context.Context, _, _ string) (net.Conn, error) { result <- d.LocalAddr return nil, nil } checkDialAddr(t, d, result, addr1) checkDialAddr(t, d, result, addr1) link, err := netlink.LinkByName("testdev1") require.NoError(t, err) require.NoError(t, netlink.LinkSetDown(link)) time.Sleep(time.Second) checkDialAddr(t, d, result, addr2) checkDialAddr(t, d, result, addr2) require.NoError(t, netlink.LinkSetUp(link)) time.Sleep(time.Second) checkDialAddr(t, d, result, addr1) }) runInNewNamespace(t, "noop balancer, remove address", func(t *testing.T, ns netns.NsHandle) { setup(t, map[string][]string{ "testdev1": {"1.2.30.11/23"}, "testdev2": {"1.2.30.12/23"}, }) addr1 := &net.TCPAddr{IP: net.IP{1, 2, 30, 11}} addr2 := &net.TCPAddr{IP: net.IP{1, 2, 30, 12}} d, err := NewDialer(Config{ Subnets: []string{"1.2.30.0/23"}, }) require.NoError(t, err) w := NewNetlinkWatcher(d) require.NoError(t, w.Start()) t.Cleanup(w.Stop) result := make(chan net.Addr, 1) d.(*dialer).testHookDialContext = func(d *net.Dialer, _ context.Context, _, _ string) (net.Conn, error) { result <- d.LocalAddr return nil, nil } checkDialAddr(t, d, result, addr1) checkDialAddr(t, d, result, addr1) link, err := netlink.LinkByName("testdev1") require.NoError(t, err) ip, err := netlink.ParseIPNet("1.2.30.11/23") require.NoError(t, err) require.NoError(t, netlink.AddrDel(link, &netlink.Addr{IPNet: ip})) time.Sleep(time.Second) checkDialAddr(t, d, result, addr2) checkDialAddr(t, d, result, addr2) require.NoError(t, netlink.AddrAdd(link, &netlink.Addr{IPNet: ip})) time.Sleep(time.Second) checkDialAddr(t, d, result, addr1) }) runInNewNamespace(t, "round-robin balancer, disable interface", func(t *testing.T, ns netns.NsHandle) { setup(t, map[string][]string{ "testdev1": {"1.2.30.11/23"}, "testdev2": {"1.2.30.12/23"}, }) addr1 := &net.TCPAddr{IP: net.IP{1, 2, 30, 11}} addr2 := &net.TCPAddr{IP: net.IP{1, 2, 30, 12}} d, err := NewDialer(Config{ Subnets: []string{"1.2.30.0/23"}, Balancer: BalancerTypeRoundRobin, }) require.NoError(t, err) w := NewNetlinkWatcher(d) require.NoError(t, w.Start()) t.Cleanup(w.Stop) result := make(chan net.Addr, 1) d.(*dialer).testHookDialContext = func(d *net.Dialer, _ context.Context, _, _ string) (net.Conn, error) { result <- d.LocalAddr return nil, nil } checkDialAddr(t, d, result, addr2) checkDialAddr(t, d, result, addr1) checkDialAddr(t, d, result, addr2) link, err := netlink.LinkByName("testdev1") require.NoError(t, err) require.NoError(t, netlink.LinkSetDown(link)) time.Sleep(time.Second) checkDialAddr(t, d, result, addr2) checkDialAddr(t, d, result, addr2) require.NoError(t, netlink.LinkSetUp(link)) time.Sleep(time.Second) checkDialAddr(t, d, result, addr1) checkDialAddr(t, d, result, addr2) }) } func checkDialAddr(t *testing.T, d Multidialer, ch chan net.Addr, expected net.Addr) { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() _, err := d.DialContext(ctx, "tcp", "1.2.30.42:12345") require.NoError(t, err) select { case addr := <-ch: require.Equal(t, expected, addr) default: require.Fail(t, "DialContext() was not called") } }