[#11] dialer: Use source IPs intead of interfaces

The use of network interfaces does not cover cases where it is necessary
to use network interfaces to access different subnets.

Signed-off-by: Dmitrii Stepanov <d.stepanov@yadro.com>
This commit is contained in:
Dmitrii Stepanov 2024-10-09 15:39:16 +03:00
parent f65c788d73
commit 4a46c8c008
Signed by: dstepanov-yadro
GPG key ID: 237AF1A763293BC0
16 changed files with 69 additions and 797 deletions

View file

@ -16,10 +16,31 @@ But sometimes you need to invent a bicycle.
## Usage ## Usage
```golang ```golang
import "git.frostfs.info/TrueCloudLab/multinet" import (
"context"
"net"
"net/netip"
"git.frostfs.info/TrueCloudLab/multinet"
)
d, err := multinet.NewDialer(Config{ d, err := multinet.NewDialer(Config{
Subnets: []string{"10.11.70.0/23", "192.168.123.0/24"}, Subnets: []Subnet{
{
Prefix: netip.MustParsePrefix("10.11.70.0/23"),
SourceIPs: []netip.Addr{
netip.MustParseAddr("10.11.70.42"),
netip.MustParseAddr("10.11.71.42"),
},
},
{
Prefix: netip.MustParsePrefix("192.168.123.0/24"),
SourceIPs: []netip.Addr{
netip.MustParseAddr("192.168.123.42"),
netip.MustParseAddr("192.168.123.142"),
},
},
},
Balancer: multinet.BalancerTypeRoundRobin, Balancer: multinet.BalancerTypeRoundRobin,
}) })
if err != nil { if err != nil {
@ -32,18 +53,3 @@ if err != nil {
} }
// do stuff // do stuff
``` ```
### Updating interface state
`Multidialer` exposes `UpdateInterface()` method for updating state of a single link.
`NetlinkWatcher` can wrap `Multidialer` type and perform all updates automatically.
TODO: describe needed capabilities here.
## Patch
To perform refactoring (use `multinet.Dial` instead of `net.Dial`) using [gopatch](https://github.com/uber-go/gopatch):
```bash
gopatch -p ./multinet.patch <project directory>
```

View file

@ -32,14 +32,10 @@ type roundRobin struct {
func (r *roundRobin) DialContext(ctx context.Context, s *Subnet, network, address string) (net.Conn, error) { func (r *roundRobin) DialContext(ctx context.Context, s *Subnet, network, address string) (net.Conn, error) {
next := int(r.i.Add(1)) next := int(r.i.Add(1))
for i := range s.Interfaces { for i := range s.SourceIPs {
ii := s.Interfaces[(i+next)%len(s.Interfaces)] ii := s.SourceIPs[(i+next)%len(s.SourceIPs)]
if ii.Down {
continue
}
dd := r.d.dialer dd := r.d.dialer
dd.LocalAddr = ii.LocalAddr dd.LocalAddr = &net.TCPAddr{IP: net.IP(ii.AsSlice())}
return r.d.dialContext(&dd, ctx, network, address) return r.d.dialContext(&dd, ctx, network, address)
} }
return nil, fmt.Errorf("(*roundRobin).DialContext: %w", errNoSuitableNodeFound) return nil, fmt.Errorf("(*roundRobin).DialContext: %w", errNoSuitableNodeFound)
@ -50,14 +46,10 @@ type firstEnabled struct {
} }
func (r *firstEnabled) DialContext(ctx context.Context, s *Subnet, network, address string) (net.Conn, error) { func (r *firstEnabled) DialContext(ctx context.Context, s *Subnet, network, address string) (net.Conn, error) {
for i := range s.Interfaces { for i := range s.SourceIPs {
ii := s.Interfaces[i%len(s.Interfaces)] ii := s.SourceIPs[i]
if ii.Down {
continue
}
dd := r.d.dialer dd := r.d.dialer
dd.LocalAddr = ii.LocalAddr dd.LocalAddr = &net.TCPAddr{IP: net.IP(ii.AsSlice())}
return r.d.dialContext(&dd, ctx, network, address) return r.d.dialContext(&dd, ctx, network, address)
} }
return nil, fmt.Errorf("(*firstEnabled).DialContext: %w", errNoSuitableNodeFound) return nil, fmt.Errorf("(*firstEnabled).DialContext: %w", errNoSuitableNodeFound)

View file

@ -1,31 +0,0 @@
package multinet
import (
"context"
"fmt"
"net"
)
var (
defaultDialer Multidialer
defaultDialerErr error
)
func init() {
var err error
defaultDialer, err = NewDialer(Config{
Balancer: BalancerTypeRoundRobin,
Subnets: []string{"0.0.0.0/0", "::/0"},
})
if err != nil {
defaultDialerErr = fmt.Errorf("failed to initialize default dialier: %w", err)
}
}
// Dial dials provided network and address using default dialer.
func Dial(network, address string) (net.Conn, error) {
if defaultDialerErr != nil {
return nil, defaultDialerErr
}
return defaultDialer.DialContext(context.Background(), network, address)
}

View file

@ -1,29 +0,0 @@
//go:build integration
package multinet
import (
"fmt"
"net/http"
"testing"
"github.com/stretchr/testify/require"
)
func TestDefaultDialer(t *testing.T) {
srv := startHTTP(t)
defer require.NoError(t, srv.Close())
conn, err := Dial("tcp", "localhost:8080")
require.NoError(t, err)
require.NoError(t, conn.Close())
}
func startHTTP(t *testing.T) *http.Server {
srv := &http.Server{Addr: ":8080"}
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { fmt.Fprintf(w, "Test stub") })
go func() {
require.ErrorIs(t, srv.ListenAndServe(), http.ErrServerClosed)
}()
return srv
}

143
dialer.go
View file

@ -1,12 +1,10 @@
package multinet package multinet
import ( import (
"bytes"
"context" "context"
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"sort"
"sync" "sync"
"time" "time"
) )
@ -20,12 +18,6 @@ type Dialer interface {
DialContext(ctx context.Context, network, address string) (net.Conn, error) DialContext(ctx context.Context, network, address string) (net.Conn, error)
} }
// Multidialer is like Dialer, but supports link state updates.
type Multidialer interface {
Dialer
UpdateInterface(name string, addr netip.Addr, status bool)
}
var ( var (
_ Dialer = (*net.Dialer)(nil) _ Dialer = (*net.Dialer)(nil)
_ Dialer = (*dialer)(nil) _ Dialer = (*dialer)(nil)
@ -49,23 +41,16 @@ type dialer struct {
fallbackDelay time.Duration fallbackDelay time.Duration
} }
// Subnet represents a single subnet, possibly routable from multiple interfaces. // Subnet represents a single subnet, possibly routable from multiple source IPs.
type Subnet struct { type Subnet struct {
Mask netip.Prefix Prefix netip.Prefix
Interfaces []Source SourceIPs []netip.Addr
}
// Source represents a single source IP belonging to a particular subnet.
type Source struct {
Name string
LocalAddr *net.TCPAddr
Down bool
} }
// Config contains Multidialer configuration. // Config contains Multidialer configuration.
type Config struct { type Config struct {
// Routable subnets to prioritize in CIDR format. // Routable subnets.
Subnets []string Subnets []Subnet
// If true, the only configurd subnets available through this dialer. // If true, the only configurd subnets available through this dialer.
// Otherwise, a failback to the net.DefaultDialer. // Otherwise, a failback to the net.DefaultDialer.
Restrict bool Restrict bool
@ -83,47 +68,15 @@ type Config struct {
// If zero, a default delay of 300ms is used. // If zero, a default delay of 300ms is used.
// A negative value disables Fast Fallback support. // A negative value disables Fast Fallback support.
FallbackDelay time.Duration FallbackDelay time.Duration
// InterfaceSource is custom `Interface`` source.
// If not specified, default implementation is used (`net.Interfaces()``).
InterfaceSource func() ([]Interface, error)
// DialContext is custom DialContext function. // DialContext is custom DialContext function.
// If not specified, default implemenattion is used (`d.DialContext(ctx, network, address)`). // If not specified, default implemenattion is used (`d.DialContext(ctx, network, address)`).
DialContext func(d *net.Dialer, ctx context.Context, network, address string) (net.Conn, error) DialContext func(d *net.Dialer, ctx context.Context, network, address string) (net.Conn, error)
} }
// NewDialer ... // NewDialer ...
func NewDialer(c Config) (Multidialer, error) { func NewDialer(c Config) (Dialer, error) {
var ifaces []Interface
var err error
if c.InterfaceSource != nil {
ifaces, err = c.InterfaceSource()
} else {
ifaces, err = systemInterfaces()
}
if err != nil {
return nil, err
}
sort.Slice(ifaces, func(i, j int) bool {
return ifaces[i].Name() < ifaces[j].Name()
})
var sources []iface
for i := range ifaces {
info, err := processIface(ifaces[i])
if err != nil {
return nil, err
}
sources = append(sources, info)
}
var d dialer var d dialer
for _, subnet := range c.Subnets { d.subnets = c.Subnets
s, err := processSubnet(subnet, sources)
if err != nil {
return nil, err
}
d.subnets = append(d.subnets, s)
}
switch c.Balancer { switch c.Balancer {
case BalancerTypeNoop: case BalancerTypeNoop:
@ -148,63 +101,6 @@ func NewDialer(c Config) (Multidialer, error) {
return &d, nil return &d, nil
} }
type iface struct {
name string
addrs []netip.Prefix
down bool
}
func processIface(info Interface) (iface, error) {
ips, err := info.Addrs()
if err != nil {
return iface{}, err
}
var addrs []netip.Prefix
for i := range ips {
p, err := netip.ParsePrefix(ips[i].String())
if err != nil {
return iface{}, err
}
addrs = append(addrs, p)
}
return iface{name: info.Name(), addrs: addrs, down: info.Down()}, nil
}
func processSubnet(subnet string, sources []iface) (Subnet, error) {
s, err := netip.ParsePrefix(subnet)
if err != nil {
return Subnet{}, err
}
var ifs []Source
for _, source := range sources {
for i := range source.addrs {
src := source.addrs[i].Addr()
if s.Contains(src) {
ifs = append(ifs, Source{
Name: source.name,
LocalAddr: &net.TCPAddr{IP: net.IP(src.AsSlice())},
Down: source.down,
})
}
}
}
sort.Slice(ifs, func(i, j int) bool {
if ifs[i].Name != ifs[j].Name {
return ifs[i].Name < ifs[j].Name
}
return bytes.Compare(ifs[i].LocalAddr.IP, ifs[j].LocalAddr.IP) == -1
})
return Subnet{
Mask: s,
Interfaces: ifs,
}, nil
}
// DialContext implements the Dialer interface. // DialContext implements the Dialer interface.
// Hostnames for address are currently not supported. // Hostnames for address are currently not supported.
func (d *dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { func (d *dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
@ -371,7 +267,7 @@ func (d *dialer) dialAddr(ctx context.Context, network, address string, addr net
defer d.mtx.RUnlock() defer d.mtx.RUnlock()
for i := range d.subnets { for i := range d.subnets {
if d.subnets[i].Mask.Contains(addr.Addr()) { if d.subnets[i].Prefix.Contains(addr.Addr()) {
return d.balancer.DialContext(ctx, &d.subnets[i], network, address) return d.balancer.DialContext(ctx, &d.subnets[i], network, address)
} }
} }
@ -389,29 +285,6 @@ func (d *dialer) dialContext(nd *net.Dialer, ctx context.Context, network, addre
return nd.DialContext(ctx, network, address) return nd.DialContext(ctx, network, address)
} }
// UpdateInterface implements the Multidialer interface.
// Updating address on a specific interface is currently not supported.
func (d *dialer) UpdateInterface(iface string, addr netip.Addr, up bool) {
d.mtx.Lock()
defer d.mtx.Unlock()
for i := range d.subnets {
for j := range d.subnets[i].Interfaces {
matchIface := d.subnets[i].Interfaces[j].Name == iface
if matchIface {
d.subnets[i].Interfaces[j].Down = !up
continue
}
a, _ := netip.AddrFromSlice(d.subnets[i].Interfaces[j].LocalAddr.IP)
matchAddr := a.IsUnspecified() || addr == a
if matchAddr {
d.subnets[i].Interfaces[j].Down = !up
}
}
}
}
// splitByType divides an address list into two categories: // splitByType divides an address list into two categories:
// the first address, and any with same type, are returned as // the first address, and any with same type, are returned as
// primaries, while addresses with the opposite type are returned // primaries, while addresses with the opposite type are returned

View file

@ -3,6 +3,7 @@ package multinet
import ( import (
"context" "context"
"net" "net"
"net/netip"
"testing" "testing"
"time" "time"
@ -14,8 +15,15 @@ func TestHostnameResolveIPv4(t *testing.T) {
resolvedAddr := "10.11.12.180:8080" resolvedAddr := "10.11.12.180:8080"
resolved := false resolved := false
d, err := NewDialer(Config{ d, err := NewDialer(Config{
Subnets: []string{"10.11.12.0/24"}, Subnets: []Subnet{
InterfaceSource: testInterfacesV4, {
Prefix: netip.MustParsePrefix("10.11.12.0/24"),
SourceIPs: []netip.Addr{
netip.MustParseAddr("10.11.12.101"),
netip.MustParseAddr("10.11.12.102"),
},
},
},
DialContext: func(d *net.Dialer, ctx context.Context, network, address string) (net.Conn, error) { DialContext: func(d *net.Dialer, ctx context.Context, network, address string) (net.Conn, error) {
if resolvedAddr == address { if resolvedAddr == address {
resolved = true resolved = true
@ -42,8 +50,15 @@ func TestHostnameResolveIPv6(t *testing.T) {
ipv6 := net.ParseIP("2001:db8:85a3:8d3:1319:8a2e:370:8195") ipv6 := net.ParseIP("2001:db8:85a3:8d3:1319:8a2e:370:8195")
resolved := false resolved := false
d, err := NewDialer(Config{ d, err := NewDialer(Config{
Subnets: []string{"2001:db8:85a3:8d3::/64"}, Subnets: []Subnet{
InterfaceSource: testInterfacesV6, {
Prefix: netip.MustParsePrefix("2001:db8:85a3:8d3::/64"),
SourceIPs: []netip.Addr{
netip.MustParseAddr("2001:db8:85a3:8d3:1319:8a2e:370:7348"),
netip.MustParseAddr("2001:db8:85a3:8d3:1319:8a2e:370:8192"),
},
},
},
DialContext: func(d *net.Dialer, ctx context.Context, network, address string) (net.Conn, error) { DialContext: func(d *net.Dialer, ctx context.Context, network, address string) (net.Conn, error) {
if resolvedAddr == address { if resolvedAddr == address {
resolved = true resolved = true
@ -65,70 +80,6 @@ func TestHostnameResolveIPv6(t *testing.T) {
require.True(t, resolved) require.True(t, resolved)
} }
func testInterfacesV4() ([]Interface, error) {
return []Interface{
&testInterface{
name: "data1",
addrs: []net.Addr{
&testAddr{
network: "tcp",
str: "10.11.12.101/24",
},
},
},
&testInterface{
name: "data2",
addrs: []net.Addr{
&testAddr{
network: "tcp",
str: "10.11.12.102/24",
},
},
},
}, nil
}
func testInterfacesV6() ([]Interface, error) {
return []Interface{
&testInterface{
name: "data1",
addrs: []net.Addr{
&testAddr{
network: "tcp",
str: "2001:db8:85a3:8d3:1319:8a2e:370:7348/64",
},
},
},
&testInterface{
name: "data2",
addrs: []net.Addr{
&testAddr{
network: "tcp",
str: "2001:db8:85a3:8d3:1319:8a2e:370:8192/64",
},
},
},
}, nil
}
type testInterface struct {
name string
addrs []net.Addr
down bool
}
func (i *testInterface) Name() string { return i.name }
func (i *testInterface) Addrs() ([]net.Addr, error) { return i.addrs, nil }
func (i *testInterface) Down() bool { return i.down }
type testAddr struct {
network string
str string
}
func (a *testAddr) Network() string { return a.network }
func (a *testAddr) String() string { return a.str }
type testDnsConn struct { type testDnsConn struct {
wantName string wantName string
ipv4 []byte ipv4 []byte

View file

@ -1,166 +0,0 @@
//go:build integration
package multinet
import (
"net"
"net/netip"
"runtime"
"testing"
"github.com/stretchr/testify/require"
"github.com/vishvananda/netlink"
"github.com/vishvananda/netns"
)
func TestDialer(t *testing.T) {
runInNewNamespace(t, "2 interfaces with multiple routes in different subnets", func(t *testing.T, ns netns.NsHandle) {
setup(t, map[string][]string{
"testdev1": {"1.2.30.10/23", "4.4.4.4/8"},
"testdev2": {"1.2.30.11/23", "4.4.4.5/8"},
})
// Do not use `t.Run` because everything should be executed in a single OS thread.
{ // Restrict to a single subnet.
d, err := NewDialer(Config{
Subnets: []string{"1.2.30.0/23"},
})
require.NoError(t, err)
require.Equal(t, []Subnet{
{
Mask: netip.MustParsePrefix("1.2.30.0/23"),
Interfaces: []Source{
{Name: "testdev1", LocalAddr: &net.TCPAddr{IP: net.IP{1, 2, 30, 10}}},
{Name: "testdev2", LocalAddr: &net.TCPAddr{IP: net.IP{1, 2, 30, 11}}},
},
},
}, d.(*dialer).subnets)
}
{ // Restrict to two subnets.
d, err := NewDialer(Config{
Subnets: []string{"1.2.30.0/23", "4.0.0.0/8"},
})
require.NoError(t, err)
require.Equal(t, []Subnet{
{
Mask: netip.MustParsePrefix("1.2.30.0/23"),
Interfaces: []Source{
{Name: "testdev1", LocalAddr: &net.TCPAddr{IP: net.IP{1, 2, 30, 10}}},
{Name: "testdev2", LocalAddr: &net.TCPAddr{IP: net.IP{1, 2, 30, 11}}},
},
},
{
Mask: netip.MustParsePrefix("4.0.0.0/8"),
Interfaces: []Source{
{Name: "testdev1", LocalAddr: &net.TCPAddr{IP: net.IP{4, 4, 4, 4}}},
{Name: "testdev2", LocalAddr: &net.TCPAddr{IP: net.IP{4, 4, 4, 5}}},
},
},
}, d.(*dialer).subnets)
}
})
runInNewNamespace(t, "4 interfaces, 2 for data, 2 internal", func(t *testing.T, ns netns.NsHandle) {
setup(t, map[string][]string{
"internal1": {"192.168.0.1/16"},
"internal2": {"192.168.0.2/16"},
"data1": {"10.11.12.101/24"},
"data2": {"10.11.12.102/24"},
})
d, err := NewDialer(Config{
Subnets: []string{"10.11.12.0/24", "192.168.0.0/16"},
})
require.NoError(t, err)
require.Equal(t, []Subnet{
{
Mask: netip.MustParsePrefix("10.11.12.0/24"),
Interfaces: []Source{
{Name: "data1", LocalAddr: &net.TCPAddr{IP: net.IP{10, 11, 12, 101}}},
{Name: "data2", LocalAddr: &net.TCPAddr{IP: net.IP{10, 11, 12, 102}}},
},
},
{
Mask: netip.MustParsePrefix("192.168.0.0/16"),
Interfaces: []Source{
{Name: "internal1", LocalAddr: &net.TCPAddr{IP: net.IP{192, 168, 0, 1}}},
{Name: "internal2", LocalAddr: &net.TCPAddr{IP: net.IP{192, 168, 0, 2}}},
},
},
}, d.(*dialer).subnets)
})
runInNewNamespace(t, "with ipv6", func(t *testing.T, ns netns.NsHandle) {
addr1 := "2001:db8:85a3:8d3:1319:8a2e:370:7348/64"
addr2 := "2001:db8:85a3:8d3:1319:8a2e:370:8192/64"
setup(t, map[string][]string{
"testdev1": {addr1},
"testdev2": {addr2},
})
// Do not use `t.Run` because everything should be executed in a single OS thread.
{ // Restrict to a single subnet.
d, err := NewDialer(Config{
Subnets: []string{"2001:db8:85a3:8d3::/64"},
})
require.NoError(t, err)
require.Equal(t, []Subnet{
{
Mask: netip.MustParsePrefix("2001:db8:85a3:8d3::/64"),
Interfaces: []Source{
{Name: "testdev1", LocalAddr: mustParseIPv6(t, addr1)},
{Name: "testdev2", LocalAddr: mustParseIPv6(t, addr2)},
},
},
}, d.(*dialer).subnets)
}
})
}
func mustParseIPv6(t *testing.T, s string) *net.TCPAddr {
ip, _, err := net.ParseCIDR(s)
require.NoError(t, err)
return &net.TCPAddr{IP: ip}
}
func setup(t *testing.T, config map[string][]string) {
for name, ips := range config {
link := createLink(t, name)
for i := range ips {
ip, err := netlink.ParseIPNet(ips[i])
require.NoError(t, err)
require.NoError(t, netlink.AddrAdd(link, &netlink.Addr{IPNet: ip}))
}
}
}
func createLink(t *testing.T, name string) netlink.Link {
require.NoError(t, netlink.LinkAdd(&netlink.Dummy{LinkAttrs: netlink.LinkAttrs{Name: name}}))
link, err := netlink.LinkByName(name)
require.NoError(t, err)
require.NoError(t, netlink.LinkSetUp(link))
return link
}
func runInNewNamespace(t *testing.T, name string, f func(t *testing.T, ns netns.NsHandle)) {
t.Run(name, func(t *testing.T) {
// To avoid messing with host network settings,
// we create a new names space and execute tests in it.
// Switching thread can move us to a different namespace, thus this line.
runtime.LockOSThread()
defer runtime.UnlockOSThread()
origns, err := netns.Get()
require.NoError(t, err)
defer origns.Close()
defer netns.Set(origns)
newns, err := netns.New()
require.NoError(t, err)
defer newns.Close()
f(t, newns)
})
}

View file

@ -2,17 +2,20 @@ package multinet
import ( import (
"context" "context"
"net" "net/netip"
"testing" "testing"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestInterfacesDown(t *testing.T) { func TestNoSourceIPs(t *testing.T) {
t.Run("noop balancer", func(t *testing.T) { t.Run("noop balancer", func(t *testing.T) {
d, err := NewDialer(Config{ d, err := NewDialer(Config{
Subnets: []string{"10.11.12.0/24"}, Subnets: []Subnet{
InterfaceSource: testDownInterfaces, {
Prefix: netip.MustParsePrefix("10.11.12.0/24"),
},
},
}) })
require.NoError(t, err) require.NoError(t, err)
conn, err := d.DialContext(context.Background(), "tcp", "10.11.12.254:8080") conn, err := d.DialContext(context.Background(), "tcp", "10.11.12.254:8080")
@ -21,8 +24,11 @@ func TestInterfacesDown(t *testing.T) {
}) })
t.Run("round robin balancer", func(t *testing.T) { t.Run("round robin balancer", func(t *testing.T) {
d, err := NewDialer(Config{ d, err := NewDialer(Config{
Subnets: []string{"10.11.12.0/24"}, Subnets: []Subnet{
InterfaceSource: testDownInterfaces, {
Prefix: netip.MustParsePrefix("10.11.12.0/24"),
},
},
Balancer: BalancerTypeRoundRobin, Balancer: BalancerTypeRoundRobin,
}) })
require.NoError(t, err) require.NoError(t, err)
@ -31,28 +37,3 @@ func TestInterfacesDown(t *testing.T) {
require.Nil(t, conn) require.Nil(t, conn)
}) })
} }
func testDownInterfaces() ([]Interface, error) {
return []Interface{
&testInterface{
name: "data1",
addrs: []net.Addr{
&testAddr{
network: "tcp",
str: "10.11.12.101/24",
},
},
down: true,
},
&testInterface{
name: "data2",
addrs: []net.Addr{
&testAddr{
network: "tcp",
str: "10.11.12.102/24",
},
},
down: true,
},
}, nil
}

3
go.mod
View file

@ -4,10 +4,7 @@ go 1.20
require ( require (
github.com/stretchr/testify v1.8.4 github.com/stretchr/testify v1.8.4
github.com/vishvananda/netlink v1.1.0
github.com/vishvananda/netns v0.0.4
golang.org/x/net v0.17.0 golang.org/x/net v0.17.0
golang.org/x/sys v0.13.0
) )
require ( require (

BIN
go.sum

Binary file not shown.

View file

@ -1,67 +0,0 @@
package multinet
import (
"net/netip"
"sync"
"github.com/vishvananda/netlink"
"golang.org/x/sys/unix"
)
type NetlinkWatcher struct {
d Multidialer
linkUpdates chan netlink.LinkUpdate
addrUpdates chan netlink.AddrUpdate
done chan struct{}
wg sync.WaitGroup
}
func NewNetlinkWatcher(d Multidialer) *NetlinkWatcher {
return &NetlinkWatcher{
d: d,
addrUpdates: make(chan netlink.AddrUpdate, 1),
linkUpdates: make(chan netlink.LinkUpdate, 1),
done: make(chan struct{}),
}
}
func (w *NetlinkWatcher) Start() error {
if err := netlink.LinkSubscribe(w.linkUpdates, w.done); err != nil {
return err
}
if err := netlink.AddrSubscribe(w.addrUpdates, w.done); err != nil {
close(w.done)
return err
}
w.wg.Add(1)
go w.watch()
return nil
}
func (w *NetlinkWatcher) watch() {
defer w.wg.Done()
for {
select {
case <-w.done:
return
case update := <-w.addrUpdates:
// Wont work if an multiple interfaces share IP address.
// Should not happen in practice.
ip, ok := netip.AddrFromSlice(update.LinkAddress.IP)
if !ok {
continue
}
w.d.UpdateInterface("", ip, update.NewAddr)
case update := <-w.linkUpdates:
up := update.Flags&unix.IFF_UP != 0
w.d.UpdateInterface(update.Link.Attrs().Name, netip.Addr{}, up)
}
}
}
func (w *NetlinkWatcher) Stop() {
close(w.done)
w.wg.Wait()
}

View file

@ -1,157 +0,0 @@
//go:build integration
package multinet
import (
"context"
"net"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/vishvananda/netlink"
"github.com/vishvananda/netns"
)
func Test_NetlinkWatcher(t *testing.T) {
runInNewNamespace(t, "noop balancer, disable interface", func(t *testing.T, ns netns.NsHandle) {
setup(t, map[string][]string{
"testdev1": {"1.2.30.11/23"},
"testdev2": {"1.2.30.12/23"},
})
addr1 := &net.TCPAddr{IP: net.IP{1, 2, 30, 11}}
addr2 := &net.TCPAddr{IP: net.IP{1, 2, 30, 12}}
result := make(chan net.Addr, 1)
d, err := NewDialer(Config{
Subnets: []string{"1.2.30.0/23"},
DialContext: func(d *net.Dialer, _ context.Context, _, _ string) (net.Conn, error) {
result <- d.LocalAddr
return nil, nil
},
})
require.NoError(t, err)
w := NewNetlinkWatcher(d)
require.NoError(t, w.Start())
t.Cleanup(w.Stop)
checkDialAddr(t, d, result, addr1)
checkDialAddr(t, d, result, addr1)
link, err := netlink.LinkByName("testdev1")
require.NoError(t, err)
require.NoError(t, netlink.LinkSetDown(link))
time.Sleep(time.Second)
checkDialAddr(t, d, result, addr2)
checkDialAddr(t, d, result, addr2)
require.NoError(t, netlink.LinkSetUp(link))
time.Sleep(time.Second)
checkDialAddr(t, d, result, addr1)
})
runInNewNamespace(t, "noop balancer, remove address", func(t *testing.T, ns netns.NsHandle) {
setup(t, map[string][]string{
"testdev1": {"1.2.30.11/23"},
"testdev2": {"1.2.30.12/23"},
})
addr1 := &net.TCPAddr{IP: net.IP{1, 2, 30, 11}}
addr2 := &net.TCPAddr{IP: net.IP{1, 2, 30, 12}}
result := make(chan net.Addr, 1)
d, err := NewDialer(Config{
Subnets: []string{"1.2.30.0/23"},
DialContext: func(d *net.Dialer, _ context.Context, _, _ string) (net.Conn, error) {
result <- d.LocalAddr
return nil, nil
},
})
require.NoError(t, err)
w := NewNetlinkWatcher(d)
require.NoError(t, w.Start())
t.Cleanup(w.Stop)
checkDialAddr(t, d, result, addr1)
checkDialAddr(t, d, result, addr1)
link, err := netlink.LinkByName("testdev1")
require.NoError(t, err)
ip, err := netlink.ParseIPNet("1.2.30.11/23")
require.NoError(t, err)
require.NoError(t, netlink.AddrDel(link, &netlink.Addr{IPNet: ip}))
time.Sleep(time.Second)
checkDialAddr(t, d, result, addr2)
checkDialAddr(t, d, result, addr2)
require.NoError(t, netlink.AddrAdd(link, &netlink.Addr{IPNet: ip}))
time.Sleep(time.Second)
checkDialAddr(t, d, result, addr1)
})
runInNewNamespace(t, "round-robin balancer, disable interface", func(t *testing.T, ns netns.NsHandle) {
setup(t, map[string][]string{
"testdev1": {"1.2.30.11/23"},
"testdev2": {"1.2.30.12/23"},
})
addr1 := &net.TCPAddr{IP: net.IP{1, 2, 30, 11}}
addr2 := &net.TCPAddr{IP: net.IP{1, 2, 30, 12}}
result := make(chan net.Addr, 1)
d, err := NewDialer(Config{
Subnets: []string{"1.2.30.0/23"},
Balancer: BalancerTypeRoundRobin,
DialContext: func(d *net.Dialer, _ context.Context, _, _ string) (net.Conn, error) {
result <- d.LocalAddr
return nil, nil
},
})
require.NoError(t, err)
w := NewNetlinkWatcher(d)
require.NoError(t, w.Start())
t.Cleanup(w.Stop)
checkDialAddr(t, d, result, addr2)
checkDialAddr(t, d, result, addr1)
checkDialAddr(t, d, result, addr2)
link, err := netlink.LinkByName("testdev1")
require.NoError(t, err)
require.NoError(t, netlink.LinkSetDown(link))
time.Sleep(time.Second)
checkDialAddr(t, d, result, addr2)
checkDialAddr(t, d, result, addr2)
require.NoError(t, netlink.LinkSetUp(link))
time.Sleep(time.Second)
checkDialAddr(t, d, result, addr1)
checkDialAddr(t, d, result, addr2)
})
}
func checkDialAddr(t *testing.T, d Multidialer, ch chan net.Addr, expected net.Addr) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
_, err := d.DialContext(ctx, "tcp", "1.2.30.42:12345")
require.NoError(t, err)
select {
case addr := <-ch:
require.Equal(t, expected, addr)
default:
require.Fail(t, "DialContext() was not called")
}
}

View file

@ -1,38 +0,0 @@
package multinet
import "net"
// Interface provides information about net.Interface.
type Interface interface {
Name() string
Addrs() ([]net.Addr, error)
Down() bool
}
type netInterface struct {
iface net.Interface
}
func (i *netInterface) Name() string {
return i.iface.Name
}
func (i *netInterface) Addrs() ([]net.Addr, error) {
return i.iface.Addrs()
}
func (i *netInterface) Down() bool {
return i.iface.Flags&net.FlagUp == 0
}
func systemInterfaces() ([]Interface, error) {
ifaces, err := net.Interfaces()
if err != nil {
return nil, err
}
var result []Interface
for _, iface := range ifaces {
result = append(result, &netInterface{iface: iface})
}
return result, nil
}

View file

@ -1,7 +0,0 @@
@@
@@
+import "git.frostfs.info/TrueCloudLab/multinet"
-import "net"
-net.Dial(...)
+multinet.Dial(...)

19
testdata/patch_0.go vendored
View file

@ -1,19 +0,0 @@
package main
import (
"log"
"net"
)
const addr = "s01.frostfs.devenv:8080"
func main() {
_, err := net.Dial(getNetwork(), addr)
if err != nil {
log.Fatal(err)
}
}
func getNetwork() string {
return "tcp"
}

14
testdata/patch_1.go vendored
View file

@ -1,14 +0,0 @@
package main
import (
"log"
"net"
)
func main() {
ip := net.IPv4(192, 168, 0, 10)
_, err := net.Dial("tcp", ip.String()+":8080")
if err != nil {
log.Fatal(err)
}
}