Support hostnames #6

Merged
fyrchik merged 4 commits from dstepanov-yadro/multinet:feat/host_names into master 2024-09-04 19:51:22 +00:00
9 changed files with 596 additions and 37 deletions

View file

@ -1,3 +1,6 @@
test:
integration-test:
# TODO figure out needed capabilities
sudo go test -count=1 -v ./...
sudo go test -count=1 -v ./... -tags=integration
test:
go test -count=1 -v ./...

View file

@ -37,7 +37,7 @@ func (r *roundRobin) DialContext(ctx context.Context, s *Subnet, network, addres
dd := r.d.dialer
dd.LocalAddr = ii.LocalAddr
return r.d.dialContext(&dd, ctx, "tcp", address)
return r.d.dialContext(&dd, ctx, network, address)
}
return nil, errors.New("(*roundRobin).DialContext: no suitale node found")
}
@ -55,7 +55,7 @@ func (r *firstEnabled) DialContext(ctx context.Context, s *Subnet, network, addr
dd := r.d.dialer
dd.LocalAddr = ii.LocalAddr
return r.d.dialContext(&dd, ctx, "tcp", address)
return r.d.dialContext(&dd, ctx, network, address)
}
return nil, errors.New("(*firstEnabled).DialContext: no suitale node found")
}

235
dialer.go
View file

@ -8,6 +8,11 @@ import (
"net/netip"
"sort"
"sync"
"time"
)
const (
defaultFallbackDelay = 300 * time.Millisecond
)
// Dialer contains the single most important method from the net.Dialer.
@ -36,8 +41,12 @@ type dialer struct {
restrict bool
// Algorithm to which picks source address for each ip.
balancer balancer
// Hook used in tests, overrides `net.Dialer.DialContext()`
testHookDialContext func(d *net.Dialer, ctx context.Context, network, address string) (net.Conn, error)
// Overrides `net.Dialer.DialContext()` if specified.
customDialContext func(d *net.Dialer, ctx context.Context, network, address string) (net.Conn, error)
// Hostname resolver
resolver net.Resolver
// See Config.FallbackDelay description.
fallbackDelay time.Duration
}
// Subnet represents a single subnet, possibly routable from multiple interfaces.
@ -65,16 +74,37 @@ type Config struct {
Dialer net.Dialer
// Balancer specifies algorithm used to pick source address.
Balancer BalancerType
// FallbackDelay specifies the length of time to wait before
// spawning a RFC 6555 Fast Fallback connection. That is, this
// is the amount of time to wait for IPv6 to succeed before
// assuming that IPv6 is misconfigured and falling back to
// IPv4.
//
// If zero, a default delay of 300ms is used.
// A negative value disables Fast Fallback support.
FallbackDelay time.Duration
// InterfaceSource is custom `Interface`` source.
// If not specified, default implementation is used (`net.Interfaces()``).
InterfaceSource func() ([]Interface, error)
fyrchik marked this conversation as resolved Outdated

Do we have usecase for this besides using in tests?

Do we have usecase for this besides using in tests?

For example tracing, logging, metrics.

For example tracing, logging, metrics.

It is called once during construction, so I doubt it is useful, but ok

It is called once during construction, so I doubt it is useful, but ok
// DialContext is custom DialContext function.
// If not specified, default implemenattion is used (`d.DialContext(ctx, network, address)`).
DialContext func(d *net.Dialer, ctx context.Context, network, address string) (net.Conn, error)
}
// NewDialer ...
func NewDialer(c Config) (Multidialer, error) {
ifaces, err := net.Interfaces()
var ifaces []Interface
var err error
if c.InterfaceSource != nil {
ifaces, err = c.InterfaceSource()
} else {
ifaces, err = systemInterfaces()
}
if err != nil {
return nil, err
}
sort.Slice(ifaces, func(i, j int) bool {
return ifaces[i].Name < ifaces[j].Name
return ifaces[i].Name() < ifaces[j].Name()
})
var sources []iface
@ -105,16 +135,25 @@ func NewDialer(c Config) (Multidialer, error) {
}
d.restrict = c.Restrict
d.resolver.Dial = d.dialContextIP
d.fallbackDelay = c.FallbackDelay
if d.fallbackDelay == 0 {
d.fallbackDelay = defaultFallbackDelay
}
d.dialer = c.Dialer
if c.DialContext != nil {
d.customDialContext = c.DialContext
}
return &d, nil
}
type iface struct {
info net.Interface
name string
addrs []netip.Prefix
}
func processIface(info net.Interface) (iface, error) {
func processIface(info Interface) (iface, error) {
ips, err := info.Addrs()
if err != nil {
return iface{}, err
@ -129,7 +168,7 @@ func processIface(info net.Interface) (iface, error) {
addrs = append(addrs, p)
}
return iface{info: info, addrs: addrs}, nil
return iface{name: info.Name(), addrs: addrs}, nil
}
func processSubnet(subnet string, sources []iface) (Subnet, error) {
@ -144,7 +183,7 @@ func processSubnet(subnet string, sources []iface) (Subnet, error) {
src := source.addrs[i].Addr()
if s.Contains(src) {
ifs = append(ifs, Source{
Name: source.info.Name,
Name: source.name,
LocalAddr: &net.TCPAddr{IP: net.IP(src.AsSlice())},
})
}
@ -168,10 +207,164 @@ func processSubnet(subnet string, sources []iface) (Subnet, error) {
// Hostnames for address are currently not supported.
func (d *dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
addr, err := netip.ParseAddrPort(address)
if err != nil { //try resolve as hostname
return d.dialContextHostname(ctx, network, address)
}
return d.dialAddr(ctx, network, address, addr)
}
func (d *dialer) dialContextIP(ctx context.Context, network, address string) (net.Conn, error) {
addr, err := netip.ParseAddrPort(address)
if err != nil {
return nil, err
}
return d.dialAddr(ctx, network, address, addr)
}
func (d *dialer) dialContextHostname(ctx context.Context, network, address string) (net.Conn, error) {
// https://github.com/golang/go/blob/release-branch.go1.21/src/net/dial.go#L488
addrPorts, err := d.resolveHostIPs(ctx, address, network)
if err != nil {
return nil, err
}
var primaries, fallbacks []netip.AddrPort

Only tcp or tcp4/tcp6 too?

Only `tcp` or `tcp4`/`tcp6` too?
In stdlib only `tcp` defined: https://github.com/golang/go/blob/master/src/net/dial.go#L502

I was looking at this function https://github.com/golang/go/blob/master/src/net/dial.go#L229, which is called inside resolveAddrList

I was looking at this function https://github.com/golang/go/blob/master/src/net/dial.go#L229, which is called inside `resolveAddrList`
if d.fallbackDelay >= 0 && network == "tcp" {
primaries, fallbacks = splitByType(addrPorts)
} else {
primaries = addrPorts
}
return d.dialParallel(ctx, network, primaries, fallbacks)
}
fyrchik marked this conversation as resolved Outdated

If this is taken from stdlib, could we supply references in comments?

If this is taken from stdlib, could we supply references in comments?

Done

Done
func (d *dialer) dialSerial(ctx context.Context, network string, addrs []netip.AddrPort) (net.Conn, error) {
var firstErr error
for _, addr := range addrs {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
conn, err := d.dialAddr(ctx, network, addr.String(), addr)
if err == nil {
return conn, nil
}
if firstErr == nil {
firstErr = err
}
}
return nil, fmt.Errorf("failed to connect any resolved address: %w", firstErr)
}
func (d *dialer) dialParallel(ctx context.Context, network string, primaries, fallbacks []netip.AddrPort) (net.Conn, error) {
if len(fallbacks) == 0 {
return d.dialSerial(ctx, network, primaries)
}
returned := make(chan struct{})
defer close(returned)
type dialResult struct {
net.Conn
error
primary bool
done bool
}
results := make(chan dialResult)
dialerFunc := func(ctx context.Context, primary bool) {
addrs := primaries
if !primary {
addrs = fallbacks
}
c, err := d.dialSerial(ctx, network, addrs)
select {
case results <- dialResult{Conn: c, error: err, primary: primary, done: true}:
case <-returned:
if c != nil {
c.Close()
}
}
}
var primary, fallback dialResult
primaryCtx, primaryCancel := context.WithCancel(ctx)
defer primaryCancel()
go dialerFunc(primaryCtx, true)
fallbackTimer := time.NewTimer(d.fallbackDelay)
defer fallbackTimer.Stop()
for {
select {
case <-fallbackTimer.C:
fallbackCtx, fallbackCancel := context.WithCancel(ctx)
defer fallbackCancel()
Review

I don't know whether it was your idea or this is copy-pasted but seems really good

I don't know whether it was your idea or this is copy-pasted but seems really good
go dialerFunc(fallbackCtx, false)
case res := <-results:
if res.error == nil {
return res.Conn, nil
}
if res.primary {
primary = res
} else {
fallback = res
}
if primary.done && fallback.done {
return nil, primary.error
}
if res.primary && fallbackTimer.Stop() {
// If we were able to stop the timer, that means it
// was running (hadn't yet started the fallback), but
// we just got an error on the primary path, so start
// the fallback immediately (in 0 nanoseconds).
fallbackTimer.Reset(0)
}
}
}
}
func (d *dialer) resolveHostIPs(ctx context.Context, address string, network string) ([]netip.AddrPort, error) {
host, port, err := net.SplitHostPort(address)
if err != nil {
return nil, fmt.Errorf("invalid address format: %w", err)
}
portnum, err := d.resolver.LookupPort(ctx, network, port)
if err != nil {
return nil, fmt.Errorf("invalid port format: %w", err)
}
ips, err := d.resolver.LookupHost(ctx, host)
if err != nil {
return nil, fmt.Errorf("failed to resolve host address: %w", err)
}
if len(ips) == 0 {
return nil, fmt.Errorf("failed to resolve address for [%s]%s", network, address)
}
var ipAddrs []netip.Addr
for _, ip := range ips {
ipAddr, err := netip.ParseAddr(ip)
if err != nil {
return nil, fmt.Errorf("failed to parse ip address '%s': %w", ip, err)
}
ipAddrs = append(ipAddrs, ipAddr)
}
var addrPorts []netip.AddrPort
for _, ipAddr := range ipAddrs {
addrPorts = append(addrPorts, netip.AddrPortFrom(ipAddr, uint16(portnum)))
}
return addrPorts, nil
}
func (d *dialer) dialAddr(ctx context.Context, network, address string, addr netip.AddrPort) (net.Conn, error) {
d.mtx.RLock()
defer d.mtx.RUnlock()
@ -184,14 +377,14 @@ func (d *dialer) DialContext(ctx context.Context, network, address string) (net.
if d.restrict {
return nil, fmt.Errorf("no suitable interface for: [%s]%s", network, address)
}
return d.dialer.DialContext(ctx, network, address)
return d.dialContext(&d.dialer, ctx, network, address)
}
func (d *dialer) dialContext(nd *net.Dialer, ctx context.Context, network, address string) (net.Conn, error) {
fyrchik marked this conversation as resolved Outdated

What was wrong with testHook? It is in spirit of the related stdlib pieces b5f87b5407/src/net/dial.go (L423)

What was wrong with `testHook`? It is in spirit of the related stdlib pieces https://github.com/golang/go/blob/b5f87b5407916c4049a3158cc944cebfd7a883a9/src/net/dial.go#L423

Nothing wrong. But I guess it could be used not only for tests. For example tracing, logging, metrics.

Nothing wrong. But I guess it could be used not only for tests. For example tracing, logging, metrics.
if h := d.testHookDialContext; h != nil {
return h(nd, ctx, "tcp", address)
if h := d.customDialContext; h != nil {
return h(nd, ctx, network, address)
}
return nd.DialContext(ctx, "tcp", address)
return nd.DialContext(ctx, network, address)
}
// UpdateInterface implements the Multidialer interface.
@ -217,3 +410,21 @@ func (d *dialer) UpdateInterface(iface string, addr netip.Addr, up bool) error {
}
return nil
}
// splitByType divides an address list into two categories:
// the first address, and any with same type, are returned as
// primaries, while addresses with the opposite type are returned
// as fallbacks.
func splitByType(addrs []netip.AddrPort) (primaries []netip.AddrPort, fallbacks []netip.AddrPort) {
var primaryLabel bool
for i, addr := range addrs {
label := addr.Addr().Is4()
if i == 0 || label == primaryLabel {
primaryLabel = label
primaries = append(primaries, addr)
} else {
fallbacks = append(fallbacks, addr)
}
}
return
}

308
dialer_hostname_test.go Normal file
View file

@ -0,0 +1,308 @@
package multinet
import (
"context"
"net"
"testing"
"time"
"github.com/stretchr/testify/require"
"golang.org/x/net/dns/dnsmessage"
)
func TestHostnameResolveIPv4(t *testing.T) {
resolvedAddr := "10.11.12.180:8080"
resolved := false
d, err := NewDialer(Config{
Subnets: []string{"10.11.12.0/24"},
InterfaceSource: testInterfacesV4,
DialContext: func(d *net.Dialer, ctx context.Context, network, address string) (net.Conn, error) {
if resolvedAddr == address {
resolved = true
return &testConnStub{}, nil
}
if network == "udp" {
return &testDnsConn{
wantName: "s01.storage.devnev.",
ipv4: []byte{10, 11, 12, 180},
}, nil
}
panic("unexpected call")
},
})
require.NoError(t, err)
conn, err := d.DialContext(context.Background(), "tcp", "s01.storage.devnev:8080")
require.NoError(t, err)
require.NoError(t, conn.Close())
require.True(t, resolved)
}
func TestHostnameResolveIPv6(t *testing.T) {
resolvedAddr := "[2001:db8:85a3:8d3:1319:8a2e:370:8195]:8080"
ipv6 := net.ParseIP("2001:db8:85a3:8d3:1319:8a2e:370:8195")
resolved := false
d, err := NewDialer(Config{
Subnets: []string{"2001:db8:85a3:8d3::/64"},
InterfaceSource: testInterfacesV6,
DialContext: func(d *net.Dialer, ctx context.Context, network, address string) (net.Conn, error) {
if resolvedAddr == address {
resolved = true
return &testConnStub{}, nil
}
if network == "udp" {
return &testDnsConn{
wantName: "s01.storage.devnev.",
ipv6: ipv6,
}, nil
}
panic("unexpected call")
},
})
require.NoError(t, err)
conn, err := d.DialContext(context.Background(), "tcp", "s01.storage.devnev:8080")
require.NoError(t, err)
require.NoError(t, conn.Close())
require.True(t, resolved)
}
func testInterfacesV4() ([]Interface, error) {
return []Interface{
&testInterface{
name: "data1",
addrs: []net.Addr{
&testAddr{
network: "tcp",
str: "10.11.12.101/24",
},
},
},
&testInterface{
name: "data2",
addrs: []net.Addr{
&testAddr{
network: "tcp",
str: "10.11.12.102/24",
},
},
},
}, nil
}
func testInterfacesV6() ([]Interface, error) {
return []Interface{
&testInterface{
name: "data1",
addrs: []net.Addr{
&testAddr{
network: "tcp",
str: "2001:db8:85a3:8d3:1319:8a2e:370:7348/64",
},
},
},
&testInterface{
name: "data2",
addrs: []net.Addr{
&testAddr{
network: "tcp",
str: "2001:db8:85a3:8d3:1319:8a2e:370:8192/64",
},
},
},
}, nil
}
type testInterface struct {
name string
addrs []net.Addr
}
func (i *testInterface) Name() string { return i.name }
func (i *testInterface) Addrs() ([]net.Addr, error) { return i.addrs, nil }
type testAddr struct {
network string
str string
}
func (a *testAddr) Network() string { return a.network }
func (a *testAddr) String() string { return a.str }
type testDnsConn struct {
wantName string
ipv4 []byte
ipv6 []byte
reqID uint16
requested bool
secondRead bool
questions []dnsmessage.Question
}
// Close implements net.Conn.
func (*testDnsConn) Close() error {
return nil
}
// LocalAddr implements net.Conn.
func (*testDnsConn) LocalAddr() net.Addr {
panic("unimplemented")
}
// Read implements net.Conn.
func (c *testDnsConn) Read(b []byte) (n int, err error) {
if c.secondRead {
data, err := c.data()
if err != nil {
return 0, err
}
copy(b, data)
return len(data), nil
}
data, err := c.data()
if err != nil {
return 0, err
}
l := len(data)
b[0] = byte(l >> 8)
b[1] = byte(l)
c.secondRead = true
return 2, nil
}
func (c *testDnsConn) data() ([]byte, error) {
if !c.requested {
builder := dnsmessage.NewBuilder(nil, dnsmessage.Header{RCode: dnsmessage.RCodeNameError})
res, err := builder.Finish()
if err != nil {
return nil, err
}
return res, nil
}
msg := dnsmessage.Message{
Header: dnsmessage.Header{Response: true, Authoritative: true, RCode: dnsmessage.RCodeSuccess, ID: c.reqID},
Questions: c.questions,
Answers: []dnsmessage.Resource{},
}
c.appendIPv4(&msg)
c.appendIPV6(&msg)
return msg.Pack()
}
func (c *testDnsConn) appendIPv4(msg *dnsmessage.Message) {
if c.ipv4 != nil {
msg.Answers = append(msg.Answers, dnsmessage.Resource{
Header: dnsmessage.ResourceHeader{
Name: dnsmessage.MustNewName(c.wantName),
Type: dnsmessage.TypeA,
Class: dnsmessage.ClassINET,
},
Body: &dnsmessage.AResource{A: [4]byte(c.ipv4)},
})
}
}
func (c *testDnsConn) appendIPV6(msg *dnsmessage.Message) {
if c.ipv6 != nil {
msg.Answers = append(msg.Answers, dnsmessage.Resource{
Header: dnsmessage.ResourceHeader{
Name: dnsmessage.MustNewName(c.wantName),
Type: dnsmessage.TypeAAAA,
Class: dnsmessage.ClassINET,
},
Body: &dnsmessage.AAAAResource{AAAA: [16]byte(c.ipv6)},
})
}
}
// RemoteAddr implements net.Conn.
func (*testDnsConn) RemoteAddr() net.Addr {
panic("unimplemented")
}
// SetDeadline implements net.Conn.
func (*testDnsConn) SetDeadline(t time.Time) error {
return nil
}
// SetReadDeadline implements net.Conn.
func (*testDnsConn) SetReadDeadline(t time.Time) error {
return nil
}
// SetWriteDeadline implements net.Conn.
func (*testDnsConn) SetWriteDeadline(t time.Time) error {
return nil
}
// Write implements net.Conn.
func (c *testDnsConn) Write(b []byte) (n int, err error) {
var p dnsmessage.Parser
var h dnsmessage.Header
if h, err = p.Start(b[2:]); err != nil {
return 0, err
}
c.questions, err = p.AllQuestions()
if err != nil {
return 0, err
}
for _, q := range c.questions {
qStr := q.Name.String()
if qStr != c.wantName {
continue
}
c.requested = true
c.reqID = h.ID
if err := p.SkipAllQuestions(); err != nil {
return 0, err
}
break
}
return len(b), nil
}
type testConnStub struct{}
// Close implements net.Conn.
func (*testConnStub) Close() error {
return nil
}
// LocalAddr implements net.Conn.
func (*testConnStub) LocalAddr() net.Addr {
panic("unimplemented")
}
// Read implements net.Conn.
func (*testConnStub) Read(b []byte) (n int, err error) {
panic("unimplemented")
}
// RemoteAddr implements net.Conn.
func (*testConnStub) RemoteAddr() net.Addr {
panic("unimplemented")
}
// SetDeadline implements net.Conn.
func (*testConnStub) SetDeadline(t time.Time) error {
panic("unimplemented")
}
// SetReadDeadline implements net.Conn.
func (*testConnStub) SetReadDeadline(t time.Time) error {
panic("unimplemented")
}
// SetWriteDeadline implements net.Conn.
func (*testConnStub) SetWriteDeadline(t time.Time) error {
panic("unimplemented")
}
// Write implements net.Conn.
func (*testConnStub) Write(b []byte) (n int, err error) {
panic("unimplemented")
}

View file

@ -1,3 +1,5 @@
//go:build integration
package multinet
import (

3
go.mod
View file

@ -6,7 +6,8 @@ require (
github.com/stretchr/testify v1.8.4
github.com/vishvananda/netlink v1.1.0
github.com/vishvananda/netns v0.0.4
golang.org/x/sys v0.2.0
golang.org/x/net v0.17.0
golang.org/x/sys v0.13.0
)
require (

6
go.sum
View file

@ -9,9 +9,11 @@ github.com/vishvananda/netlink v1.1.0/go.mod h1:cTgwzPIzzgDAYoQrMm0EdrjRUBkTqKYp
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df/go.mod h1:JP3t17pCcGlemwknint6hfoeCVQrEMVwxRLRjXpq+BU=
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
golang.org/x/sys v0.0.0-20190606203320-7fc4e5ec1444/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.2.0 h1:ljd4t30dBnAvMZaQCevtY0xLLD0A+bRZXbgLMLU1F/A=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=

View file

@ -1,3 +1,5 @@
//go:build integration
package multinet
import (
@ -21,8 +23,13 @@ func Test_NetlinkWatcher(t *testing.T) {
addr1 := &net.TCPAddr{IP: net.IP{1, 2, 30, 11}}
addr2 := &net.TCPAddr{IP: net.IP{1, 2, 30, 12}}
result := make(chan net.Addr, 1)
d, err := NewDialer(Config{
Subnets: []string{"1.2.30.0/23"},
DialContext: func(d *net.Dialer, _ context.Context, _, _ string) (net.Conn, error) {
result <- d.LocalAddr
return nil, nil
},
})
require.NoError(t, err)
@ -30,12 +37,6 @@ func Test_NetlinkWatcher(t *testing.T) {
require.NoError(t, w.Start())
t.Cleanup(w.Stop)
result := make(chan net.Addr, 1)
d.(*dialer).testHookDialContext = func(d *net.Dialer, _ context.Context, _, _ string) (net.Conn, error) {
result <- d.LocalAddr
return nil, nil
}
checkDialAddr(t, d, result, addr1)
checkDialAddr(t, d, result, addr1)
@ -62,8 +63,13 @@ func Test_NetlinkWatcher(t *testing.T) {
addr1 := &net.TCPAddr{IP: net.IP{1, 2, 30, 11}}
addr2 := &net.TCPAddr{IP: net.IP{1, 2, 30, 12}}
result := make(chan net.Addr, 1)
d, err := NewDialer(Config{
Subnets: []string{"1.2.30.0/23"},
DialContext: func(d *net.Dialer, _ context.Context, _, _ string) (net.Conn, error) {
result <- d.LocalAddr
return nil, nil
},
})
require.NoError(t, err)
@ -71,12 +77,6 @@ func Test_NetlinkWatcher(t *testing.T) {
require.NoError(t, w.Start())
t.Cleanup(w.Stop)
result := make(chan net.Addr, 1)
d.(*dialer).testHookDialContext = func(d *net.Dialer, _ context.Context, _, _ string) (net.Conn, error) {
result <- d.LocalAddr
return nil, nil
}
checkDialAddr(t, d, result, addr1)
checkDialAddr(t, d, result, addr1)
@ -106,9 +106,14 @@ func Test_NetlinkWatcher(t *testing.T) {
addr1 := &net.TCPAddr{IP: net.IP{1, 2, 30, 11}}
addr2 := &net.TCPAddr{IP: net.IP{1, 2, 30, 12}}
result := make(chan net.Addr, 1)
d, err := NewDialer(Config{
Subnets: []string{"1.2.30.0/23"},
Balancer: BalancerTypeRoundRobin,
DialContext: func(d *net.Dialer, _ context.Context, _, _ string) (net.Conn, error) {
result <- d.LocalAddr
return nil, nil
},
})
require.NoError(t, err)
@ -116,12 +121,6 @@ func Test_NetlinkWatcher(t *testing.T) {
require.NoError(t, w.Start())
t.Cleanup(w.Stop)
result := make(chan net.Addr, 1)
d.(*dialer).testHookDialContext = func(d *net.Dialer, _ context.Context, _, _ string) (net.Conn, error) {
result <- d.LocalAddr
return nil, nil
}
checkDialAddr(t, d, result, addr2)
checkDialAddr(t, d, result, addr1)
checkDialAddr(t, d, result, addr2)

33
interface.go Normal file
View file

@ -0,0 +1,33 @@
package multinet
import "net"
// Interface provides information about net.Interface.
type Interface interface {
Name() string
Addrs() ([]net.Addr, error)
}
type netInterface struct {
iface net.Interface
}
func (i *netInterface) Name() string {
return i.iface.Name
}
func (i *netInterface) Addrs() ([]net.Addr, error) {
return i.iface.Addrs()
}
func systemInterfaces() ([]Interface, error) {
ifaces, err := net.Interfaces()
if err != nil {
return nil, err
}
var result []Interface
for _, iface := range ifaces {
result = append(result, &netInterface{iface: iface})
}
return result, nil
}