157 lines
4 KiB
Go
157 lines
4 KiB
Go
//go:build integration
|
|
|
|
package multinet
|
|
|
|
import (
|
|
"context"
|
|
"net"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/stretchr/testify/require"
|
|
"github.com/vishvananda/netlink"
|
|
"github.com/vishvananda/netns"
|
|
)
|
|
|
|
func Test_NetlinkWatcher(t *testing.T) {
|
|
runInNewNamespace(t, "noop balancer, disable interface", func(t *testing.T, ns netns.NsHandle) {
|
|
setup(t, map[string][]string{
|
|
"testdev1": {"1.2.30.11/23"},
|
|
"testdev2": {"1.2.30.12/23"},
|
|
})
|
|
|
|
addr1 := &net.TCPAddr{IP: net.IP{1, 2, 30, 11}}
|
|
addr2 := &net.TCPAddr{IP: net.IP{1, 2, 30, 12}}
|
|
|
|
result := make(chan net.Addr, 1)
|
|
d, err := NewDialer(Config{
|
|
Subnets: []string{"1.2.30.0/23"},
|
|
DialContext: func(d *net.Dialer, _ context.Context, _, _ string) (net.Conn, error) {
|
|
result <- d.LocalAddr
|
|
return nil, nil
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
w := NewNetlinkWatcher(d)
|
|
require.NoError(t, w.Start())
|
|
t.Cleanup(w.Stop)
|
|
|
|
checkDialAddr(t, d, result, addr1)
|
|
checkDialAddr(t, d, result, addr1)
|
|
|
|
link, err := netlink.LinkByName("testdev1")
|
|
require.NoError(t, err)
|
|
require.NoError(t, netlink.LinkSetDown(link))
|
|
time.Sleep(time.Second)
|
|
|
|
checkDialAddr(t, d, result, addr2)
|
|
checkDialAddr(t, d, result, addr2)
|
|
|
|
require.NoError(t, netlink.LinkSetUp(link))
|
|
time.Sleep(time.Second)
|
|
|
|
checkDialAddr(t, d, result, addr1)
|
|
})
|
|
|
|
runInNewNamespace(t, "noop balancer, remove address", func(t *testing.T, ns netns.NsHandle) {
|
|
setup(t, map[string][]string{
|
|
"testdev1": {"1.2.30.11/23"},
|
|
"testdev2": {"1.2.30.12/23"},
|
|
})
|
|
|
|
addr1 := &net.TCPAddr{IP: net.IP{1, 2, 30, 11}}
|
|
addr2 := &net.TCPAddr{IP: net.IP{1, 2, 30, 12}}
|
|
|
|
result := make(chan net.Addr, 1)
|
|
d, err := NewDialer(Config{
|
|
Subnets: []string{"1.2.30.0/23"},
|
|
DialContext: func(d *net.Dialer, _ context.Context, _, _ string) (net.Conn, error) {
|
|
result <- d.LocalAddr
|
|
return nil, nil
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
w := NewNetlinkWatcher(d)
|
|
require.NoError(t, w.Start())
|
|
t.Cleanup(w.Stop)
|
|
|
|
checkDialAddr(t, d, result, addr1)
|
|
checkDialAddr(t, d, result, addr1)
|
|
|
|
link, err := netlink.LinkByName("testdev1")
|
|
require.NoError(t, err)
|
|
|
|
ip, err := netlink.ParseIPNet("1.2.30.11/23")
|
|
require.NoError(t, err)
|
|
require.NoError(t, netlink.AddrDel(link, &netlink.Addr{IPNet: ip}))
|
|
time.Sleep(time.Second)
|
|
|
|
checkDialAddr(t, d, result, addr2)
|
|
checkDialAddr(t, d, result, addr2)
|
|
|
|
require.NoError(t, netlink.AddrAdd(link, &netlink.Addr{IPNet: ip}))
|
|
time.Sleep(time.Second)
|
|
|
|
checkDialAddr(t, d, result, addr1)
|
|
})
|
|
|
|
runInNewNamespace(t, "round-robin balancer, disable interface", func(t *testing.T, ns netns.NsHandle) {
|
|
setup(t, map[string][]string{
|
|
"testdev1": {"1.2.30.11/23"},
|
|
"testdev2": {"1.2.30.12/23"},
|
|
})
|
|
|
|
addr1 := &net.TCPAddr{IP: net.IP{1, 2, 30, 11}}
|
|
addr2 := &net.TCPAddr{IP: net.IP{1, 2, 30, 12}}
|
|
|
|
result := make(chan net.Addr, 1)
|
|
d, err := NewDialer(Config{
|
|
Subnets: []string{"1.2.30.0/23"},
|
|
Balancer: BalancerTypeRoundRobin,
|
|
DialContext: func(d *net.Dialer, _ context.Context, _, _ string) (net.Conn, error) {
|
|
result <- d.LocalAddr
|
|
return nil, nil
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
w := NewNetlinkWatcher(d)
|
|
require.NoError(t, w.Start())
|
|
t.Cleanup(w.Stop)
|
|
|
|
checkDialAddr(t, d, result, addr2)
|
|
checkDialAddr(t, d, result, addr1)
|
|
checkDialAddr(t, d, result, addr2)
|
|
|
|
link, err := netlink.LinkByName("testdev1")
|
|
require.NoError(t, err)
|
|
require.NoError(t, netlink.LinkSetDown(link))
|
|
time.Sleep(time.Second)
|
|
|
|
checkDialAddr(t, d, result, addr2)
|
|
checkDialAddr(t, d, result, addr2)
|
|
|
|
require.NoError(t, netlink.LinkSetUp(link))
|
|
time.Sleep(time.Second)
|
|
|
|
checkDialAddr(t, d, result, addr1)
|
|
checkDialAddr(t, d, result, addr2)
|
|
})
|
|
}
|
|
|
|
func checkDialAddr(t *testing.T, d Multidialer, ch chan net.Addr, expected net.Addr) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
|
defer cancel()
|
|
|
|
_, err := d.DialContext(ctx, "tcp", "1.2.30.42:12345")
|
|
require.NoError(t, err)
|
|
|
|
select {
|
|
case addr := <-ch:
|
|
require.Equal(t, expected, addr)
|
|
default:
|
|
require.Fail(t, "DialContext() was not called")
|
|
}
|
|
}
|