diff --git a/Makefile b/Makefile index d18847b..25f9279 100644 --- a/Makefile +++ b/Makefile @@ -1,3 +1,6 @@ integration-test: # TODO figure out needed capabilities sudo go test -count=1 -v ./... -tags=integration + +test: + go test -count=1 -v ./... diff --git a/dialer.go b/dialer.go index 086b9b8..a78e3b5 100644 --- a/dialer.go +++ b/dialer.go @@ -41,8 +41,8 @@ type dialer struct { restrict bool // Algorithm to which picks source address for each ip. balancer balancer - // Hook used in tests, overrides `net.Dialer.DialContext()` - testHookDialContext func(d *net.Dialer, ctx context.Context, network, address string) (net.Conn, error) + // Overrides `net.Dialer.DialContext()` if specified. + customDialContext func(d *net.Dialer, ctx context.Context, network, address string) (net.Conn, error) // Hostname resolver resolver net.Resolver // See Config.FallbackDelay description. @@ -83,16 +83,28 @@ type Config struct { // If zero, a default delay of 300ms is used. // A negative value disables Fast Fallback support. FallbackDelay time.Duration + // InterfaceSource is custom `Interface`` source. + // If not specified, default implementation is used (`net.Interfaces()``). + InterfaceSource func() ([]Interface, error) + // DialContext is custom DialContext function. + // If not specified, default implemenattion is used (`d.DialContext(ctx, network, address)`). + DialContext func(d *net.Dialer, ctx context.Context, network, address string) (net.Conn, error) } // NewDialer ... func NewDialer(c Config) (Multidialer, error) { - ifaces, err := net.Interfaces() + var ifaces []Interface + var err error + if c.InterfaceSource != nil { + ifaces, err = c.InterfaceSource() + } else { + ifaces, err = systemInterfaces() + } if err != nil { return nil, err } sort.Slice(ifaces, func(i, j int) bool { - return ifaces[i].Name < ifaces[j].Name + return ifaces[i].Name() < ifaces[j].Name() }) var sources []iface @@ -128,16 +140,20 @@ func NewDialer(c Config) (Multidialer, error) { if d.fallbackDelay == 0 { d.fallbackDelay = defaultFallbackDelay } + d.dialer = c.Dialer + if c.DialContext != nil { + d.customDialContext = c.DialContext + } return &d, nil } type iface struct { - info net.Interface + name string addrs []netip.Prefix } -func processIface(info net.Interface) (iface, error) { +func processIface(info Interface) (iface, error) { ips, err := info.Addrs() if err != nil { return iface{}, err @@ -152,7 +168,7 @@ func processIface(info net.Interface) (iface, error) { addrs = append(addrs, p) } - return iface{info: info, addrs: addrs}, nil + return iface{name: info.Name(), addrs: addrs}, nil } func processSubnet(subnet string, sources []iface) (Subnet, error) { @@ -167,7 +183,7 @@ func processSubnet(subnet string, sources []iface) (Subnet, error) { src := source.addrs[i].Addr() if s.Contains(src) { ifs = append(ifs, Source{ - Name: source.info.Name, + Name: source.name, LocalAddr: &net.TCPAddr{IP: net.IP(src.AsSlice())}, }) } @@ -361,11 +377,11 @@ func (d *dialer) dialAddr(ctx context.Context, network, address string, addr net if d.restrict { return nil, fmt.Errorf("no suitable interface for: [%s]%s", network, address) } - return d.dialer.DialContext(ctx, network, address) + return d.dialContext(&d.dialer, ctx, network, address) } func (d *dialer) dialContext(nd *net.Dialer, ctx context.Context, network, address string) (net.Conn, error) { - if h := d.testHookDialContext; h != nil { + if h := d.customDialContext; h != nil { return h(nd, ctx, network, address) } return nd.DialContext(ctx, network, address) diff --git a/dialer_hostname_test.go b/dialer_hostname_test.go new file mode 100644 index 0000000..1df305c --- /dev/null +++ b/dialer_hostname_test.go @@ -0,0 +1,308 @@ +package multinet + +import ( + "context" + "net" + "testing" + "time" + + "github.com/stretchr/testify/require" + "golang.org/x/net/dns/dnsmessage" +) + +func TestHostnameResolveIPv4(t *testing.T) { + resolvedAddr := "10.11.12.180:8080" + resolved := false + d, err := NewDialer(Config{ + Subnets: []string{"10.11.12.0/24"}, + InterfaceSource: testInterfacesV4, + DialContext: func(d *net.Dialer, ctx context.Context, network, address string) (net.Conn, error) { + if resolvedAddr == address { + resolved = true + return &testConnStub{}, nil + } + if network == "udp" { + return &testDnsConn{ + wantName: "s01.storage.devnev.", + ipv4: []byte{10, 11, 12, 180}, + }, nil + } + panic("unexpected call") + }, + }) + require.NoError(t, err) + conn, err := d.DialContext(context.Background(), "tcp", "s01.storage.devnev:8080") + require.NoError(t, err) + require.NoError(t, conn.Close()) + require.True(t, resolved) +} + +func TestHostnameResolveIPv6(t *testing.T) { + resolvedAddr := "[2001:db8:85a3:8d3:1319:8a2e:370:8195]:8080" + ipv6 := net.ParseIP("2001:db8:85a3:8d3:1319:8a2e:370:8195") + resolved := false + d, err := NewDialer(Config{ + Subnets: []string{"2001:db8:85a3:8d3::/64"}, + InterfaceSource: testInterfacesV6, + DialContext: func(d *net.Dialer, ctx context.Context, network, address string) (net.Conn, error) { + if resolvedAddr == address { + resolved = true + return &testConnStub{}, nil + } + if network == "udp" { + return &testDnsConn{ + wantName: "s01.storage.devnev.", + ipv6: ipv6, + }, nil + } + panic("unexpected call") + }, + }) + require.NoError(t, err) + conn, err := d.DialContext(context.Background(), "tcp", "s01.storage.devnev:8080") + require.NoError(t, err) + require.NoError(t, conn.Close()) + require.True(t, resolved) +} + +func testInterfacesV4() ([]Interface, error) { + return []Interface{ + &testInterface{ + name: "data1", + addrs: []net.Addr{ + &testAddr{ + network: "tcp", + str: "10.11.12.101/24", + }, + }, + }, + &testInterface{ + name: "data2", + addrs: []net.Addr{ + &testAddr{ + network: "tcp", + str: "10.11.12.102/24", + }, + }, + }, + }, nil +} + +func testInterfacesV6() ([]Interface, error) { + return []Interface{ + &testInterface{ + name: "data1", + addrs: []net.Addr{ + &testAddr{ + network: "tcp", + str: "2001:db8:85a3:8d3:1319:8a2e:370:7348/64", + }, + }, + }, + &testInterface{ + name: "data2", + addrs: []net.Addr{ + &testAddr{ + network: "tcp", + str: "2001:db8:85a3:8d3:1319:8a2e:370:8192/64", + }, + }, + }, + }, nil +} + +type testInterface struct { + name string + addrs []net.Addr +} + +func (i *testInterface) Name() string { return i.name } +func (i *testInterface) Addrs() ([]net.Addr, error) { return i.addrs, nil } + +type testAddr struct { + network string + str string +} + +func (a *testAddr) Network() string { return a.network } +func (a *testAddr) String() string { return a.str } + +type testDnsConn struct { + wantName string + ipv4 []byte + ipv6 []byte + reqID uint16 + requested bool + secondRead bool + questions []dnsmessage.Question +} + +// Close implements net.Conn. +func (*testDnsConn) Close() error { + return nil +} + +// LocalAddr implements net.Conn. +func (*testDnsConn) LocalAddr() net.Addr { + panic("unimplemented") +} + +// Read implements net.Conn. +func (c *testDnsConn) Read(b []byte) (n int, err error) { + if c.secondRead { + data, err := c.data() + if err != nil { + return 0, err + } + copy(b, data) + return len(data), nil + } + + data, err := c.data() + if err != nil { + return 0, err + } + l := len(data) + b[0] = byte(l >> 8) + b[1] = byte(l) + c.secondRead = true + return 2, nil +} + +func (c *testDnsConn) data() ([]byte, error) { + if !c.requested { + builder := dnsmessage.NewBuilder(nil, dnsmessage.Header{RCode: dnsmessage.RCodeNameError}) + res, err := builder.Finish() + if err != nil { + return nil, err + } + return res, nil + } + + msg := dnsmessage.Message{ + Header: dnsmessage.Header{Response: true, Authoritative: true, RCode: dnsmessage.RCodeSuccess, ID: c.reqID}, + Questions: c.questions, + Answers: []dnsmessage.Resource{}, + } + c.appendIPv4(&msg) + c.appendIPV6(&msg) + return msg.Pack() +} + +func (c *testDnsConn) appendIPv4(msg *dnsmessage.Message) { + if c.ipv4 != nil { + msg.Answers = append(msg.Answers, dnsmessage.Resource{ + Header: dnsmessage.ResourceHeader{ + Name: dnsmessage.MustNewName(c.wantName), + Type: dnsmessage.TypeA, + Class: dnsmessage.ClassINET, + }, + Body: &dnsmessage.AResource{A: [4]byte(c.ipv4)}, + }) + } +} + +func (c *testDnsConn) appendIPV6(msg *dnsmessage.Message) { + if c.ipv6 != nil { + msg.Answers = append(msg.Answers, dnsmessage.Resource{ + Header: dnsmessage.ResourceHeader{ + Name: dnsmessage.MustNewName(c.wantName), + Type: dnsmessage.TypeAAAA, + Class: dnsmessage.ClassINET, + }, + Body: &dnsmessage.AAAAResource{AAAA: [16]byte(c.ipv6)}, + }) + } +} + +// RemoteAddr implements net.Conn. +func (*testDnsConn) RemoteAddr() net.Addr { + panic("unimplemented") +} + +// SetDeadline implements net.Conn. +func (*testDnsConn) SetDeadline(t time.Time) error { + return nil +} + +// SetReadDeadline implements net.Conn. +func (*testDnsConn) SetReadDeadline(t time.Time) error { + return nil +} + +// SetWriteDeadline implements net.Conn. +func (*testDnsConn) SetWriteDeadline(t time.Time) error { + return nil +} + +// Write implements net.Conn. +func (c *testDnsConn) Write(b []byte) (n int, err error) { + var p dnsmessage.Parser + var h dnsmessage.Header + if h, err = p.Start(b[2:]); err != nil { + return 0, err + } + + c.questions, err = p.AllQuestions() + if err != nil { + return 0, err + } + + for _, q := range c.questions { + qStr := q.Name.String() + if qStr != c.wantName { + continue + } + + c.requested = true + c.reqID = h.ID + if err := p.SkipAllQuestions(); err != nil { + return 0, err + } + break + } + + return len(b), nil +} + +type testConnStub struct{} + +// Close implements net.Conn. +func (*testConnStub) Close() error { + return nil +} + +// LocalAddr implements net.Conn. +func (*testConnStub) LocalAddr() net.Addr { + panic("unimplemented") +} + +// Read implements net.Conn. +func (*testConnStub) Read(b []byte) (n int, err error) { + panic("unimplemented") +} + +// RemoteAddr implements net.Conn. +func (*testConnStub) RemoteAddr() net.Addr { + panic("unimplemented") +} + +// SetDeadline implements net.Conn. +func (*testConnStub) SetDeadline(t time.Time) error { + panic("unimplemented") +} + +// SetReadDeadline implements net.Conn. +func (*testConnStub) SetReadDeadline(t time.Time) error { + panic("unimplemented") +} + +// SetWriteDeadline implements net.Conn. +func (*testConnStub) SetWriteDeadline(t time.Time) error { + panic("unimplemented") +} + +// Write implements net.Conn. +func (*testConnStub) Write(b []byte) (n int, err error) { + panic("unimplemented") +} diff --git a/go.mod b/go.mod index 6ee639f..d6015ad 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,8 @@ require ( github.com/stretchr/testify v1.8.4 github.com/vishvananda/netlink v1.1.0 github.com/vishvananda/netns v0.0.4 - golang.org/x/sys v0.2.0 + golang.org/x/net v0.17.0 + golang.org/x/sys v0.13.0 ) require ( diff --git a/go.sum b/go.sum index 83fb4e7..899346c 100644 --- a/go.sum +++ b/go.sum @@ -9,9 +9,11 @@ github.com/vishvananda/netlink v1.1.0/go.mod h1:cTgwzPIzzgDAYoQrMm0EdrjRUBkTqKYp github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df/go.mod h1:JP3t17pCcGlemwknint6hfoeCVQrEMVwxRLRjXpq+BU= github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= +golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= +golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= golang.org/x/sys v0.0.0-20190606203320-7fc4e5ec1444/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.2.0 h1:ljd4t30dBnAvMZaQCevtY0xLLD0A+bRZXbgLMLU1F/A= -golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= +golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/health_integration_test.go b/health_integration_test.go index 95c893b..fdd048b 100644 --- a/health_integration_test.go +++ b/health_integration_test.go @@ -23,8 +23,13 @@ func Test_NetlinkWatcher(t *testing.T) { addr1 := &net.TCPAddr{IP: net.IP{1, 2, 30, 11}} addr2 := &net.TCPAddr{IP: net.IP{1, 2, 30, 12}} + result := make(chan net.Addr, 1) d, err := NewDialer(Config{ Subnets: []string{"1.2.30.0/23"}, + DialContext: func(d *net.Dialer, _ context.Context, _, _ string) (net.Conn, error) { + result <- d.LocalAddr + return nil, nil + }, }) require.NoError(t, err) @@ -32,12 +37,6 @@ func Test_NetlinkWatcher(t *testing.T) { require.NoError(t, w.Start()) t.Cleanup(w.Stop) - result := make(chan net.Addr, 1) - d.(*dialer).testHookDialContext = func(d *net.Dialer, _ context.Context, _, _ string) (net.Conn, error) { - result <- d.LocalAddr - return nil, nil - } - checkDialAddr(t, d, result, addr1) checkDialAddr(t, d, result, addr1) @@ -64,8 +63,13 @@ func Test_NetlinkWatcher(t *testing.T) { addr1 := &net.TCPAddr{IP: net.IP{1, 2, 30, 11}} addr2 := &net.TCPAddr{IP: net.IP{1, 2, 30, 12}} + result := make(chan net.Addr, 1) d, err := NewDialer(Config{ Subnets: []string{"1.2.30.0/23"}, + DialContext: func(d *net.Dialer, _ context.Context, _, _ string) (net.Conn, error) { + result <- d.LocalAddr + return nil, nil + }, }) require.NoError(t, err) @@ -73,12 +77,6 @@ func Test_NetlinkWatcher(t *testing.T) { require.NoError(t, w.Start()) t.Cleanup(w.Stop) - result := make(chan net.Addr, 1) - d.(*dialer).testHookDialContext = func(d *net.Dialer, _ context.Context, _, _ string) (net.Conn, error) { - result <- d.LocalAddr - return nil, nil - } - checkDialAddr(t, d, result, addr1) checkDialAddr(t, d, result, addr1) @@ -108,9 +106,14 @@ func Test_NetlinkWatcher(t *testing.T) { addr1 := &net.TCPAddr{IP: net.IP{1, 2, 30, 11}} addr2 := &net.TCPAddr{IP: net.IP{1, 2, 30, 12}} + result := make(chan net.Addr, 1) d, err := NewDialer(Config{ Subnets: []string{"1.2.30.0/23"}, Balancer: BalancerTypeRoundRobin, + DialContext: func(d *net.Dialer, _ context.Context, _, _ string) (net.Conn, error) { + result <- d.LocalAddr + return nil, nil + }, }) require.NoError(t, err) @@ -118,12 +121,6 @@ func Test_NetlinkWatcher(t *testing.T) { require.NoError(t, w.Start()) t.Cleanup(w.Stop) - result := make(chan net.Addr, 1) - d.(*dialer).testHookDialContext = func(d *net.Dialer, _ context.Context, _, _ string) (net.Conn, error) { - result <- d.LocalAddr - return nil, nil - } - checkDialAddr(t, d, result, addr2) checkDialAddr(t, d, result, addr1) checkDialAddr(t, d, result, addr2) diff --git a/interface.go b/interface.go new file mode 100644 index 0000000..0d9677a --- /dev/null +++ b/interface.go @@ -0,0 +1,33 @@ +package multinet + +import "net" + +// Interface provides information about net.Interface. +type Interface interface { + Name() string + Addrs() ([]net.Addr, error) +} + +type netInterface struct { + iface net.Interface +} + +func (i *netInterface) Name() string { + return i.iface.Name +} + +func (i *netInterface) Addrs() ([]net.Addr, error) { + return i.iface.Addrs() +} + +func systemInterfaces() ([]Interface, error) { + ifaces, err := net.Interfaces() + if err != nil { + return nil, err + } + var result []Interface + for _, iface := range ifaces { + result = append(result, &netInterface{iface: iface}) + } + return result, nil +}