diff --git a/balancer.go b/balancer.go index 6454aa9..31e0184 100644 --- a/balancer.go +++ b/balancer.go @@ -3,6 +3,7 @@ package multinet import ( "context" "errors" + "fmt" "net" "sync/atomic" ) @@ -18,6 +19,8 @@ const ( BalancerTypeRoundRobin BalancerType = "roundrobin" ) +var errNoSuitableNodeFound = errors.New("no suitale node found") + type balancer interface { DialContext(ctx context.Context, s *Subnet, network, address string) (net.Conn, error) } @@ -39,7 +42,7 @@ func (r *roundRobin) DialContext(ctx context.Context, s *Subnet, network, addres dd.LocalAddr = ii.LocalAddr return r.d.dialContext(&dd, ctx, network, address) } - return nil, errors.New("(*roundRobin).DialContext: no suitale node found") + return nil, fmt.Errorf("(*roundRobin).DialContext: %w", errNoSuitableNodeFound) } type firstEnabled struct { @@ -57,5 +60,5 @@ func (r *firstEnabled) DialContext(ctx context.Context, s *Subnet, network, addr dd.LocalAddr = ii.LocalAddr return r.d.dialContext(&dd, ctx, network, address) } - return nil, errors.New("(*firstEnabled).DialContext: no suitale node found") + return nil, fmt.Errorf("(*firstEnabled).DialContext: %w", errNoSuitableNodeFound) } diff --git a/dialer.go b/dialer.go index 9b242ff..c0e1e96 100644 --- a/dialer.go +++ b/dialer.go @@ -151,6 +151,7 @@ func NewDialer(c Config) (Multidialer, error) { type iface struct { name string addrs []netip.Prefix + down bool } func processIface(info Interface) (iface, error) { @@ -168,7 +169,7 @@ func processIface(info Interface) (iface, error) { addrs = append(addrs, p) } - return iface{name: info.Name(), addrs: addrs}, nil + return iface{name: info.Name(), addrs: addrs, down: info.Down()}, nil } func processSubnet(subnet string, sources []iface) (Subnet, error) { @@ -185,6 +186,7 @@ func processSubnet(subnet string, sources []iface) (Subnet, error) { ifs = append(ifs, Source{ Name: source.name, LocalAddr: &net.TCPAddr{IP: net.IP(src.AsSlice())}, + Down: source.down, }) } } diff --git a/dialer_hostname_test.go b/dialer_hostname_test.go index 1df305c..58d90a5 100644 --- a/dialer_hostname_test.go +++ b/dialer_hostname_test.go @@ -114,10 +114,12 @@ func testInterfacesV6() ([]Interface, error) { type testInterface struct { name string addrs []net.Addr + down bool } func (i *testInterface) Name() string { return i.name } func (i *testInterface) Addrs() ([]net.Addr, error) { return i.addrs, nil } +func (i *testInterface) Down() bool { return i.down } type testAddr struct { network string diff --git a/dialer_test.go b/dialer_test.go new file mode 100644 index 0000000..36ac5af --- /dev/null +++ b/dialer_test.go @@ -0,0 +1,58 @@ +package multinet + +import ( + "context" + "net" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestInterfacesDown(t *testing.T) { + t.Run("noop balancer", func(t *testing.T) { + d, err := NewDialer(Config{ + Subnets: []string{"10.11.12.0/24"}, + InterfaceSource: testDownInterfaces, + }) + require.NoError(t, err) + conn, err := d.DialContext(context.Background(), "tcp", "10.11.12.254:8080") + require.ErrorIs(t, err, errNoSuitableNodeFound) + require.Nil(t, conn) + }) + t.Run("round robin balancer", func(t *testing.T) { + d, err := NewDialer(Config{ + Subnets: []string{"10.11.12.0/24"}, + InterfaceSource: testDownInterfaces, + Balancer: BalancerTypeRoundRobin, + }) + require.NoError(t, err) + conn, err := d.DialContext(context.Background(), "tcp", "10.11.12.254:8080") + require.ErrorIs(t, err, errNoSuitableNodeFound) + require.Nil(t, conn) + }) +} + +func testDownInterfaces() ([]Interface, error) { + return []Interface{ + &testInterface{ + name: "data1", + addrs: []net.Addr{ + &testAddr{ + network: "tcp", + str: "10.11.12.101/24", + }, + }, + down: true, + }, + &testInterface{ + name: "data2", + addrs: []net.Addr{ + &testAddr{ + network: "tcp", + str: "10.11.12.102/24", + }, + }, + down: true, + }, + }, nil +} diff --git a/interface.go b/interface.go index 0d9677a..87b0fb5 100644 --- a/interface.go +++ b/interface.go @@ -6,6 +6,7 @@ import "net" type Interface interface { Name() string Addrs() ([]net.Addr, error) + Down() bool } type netInterface struct { @@ -20,6 +21,10 @@ func (i *netInterface) Addrs() ([]net.Addr, error) { return i.iface.Addrs() } +func (i *netInterface) Down() bool { + return i.iface.Flags&net.FlagUp == 0 +} + func systemInterfaces() ([]Interface, error) { ifaces, err := net.Interfaces() if err != nil {