Support hostnames #6
9 changed files with 592 additions and 35 deletions
7
Makefile
7
Makefile
|
@ -1,3 +1,6 @@
|
|||
test:
|
||||
integration-test:
|
||||
# TODO figure out needed capabilities
|
||||
sudo go test -count=1 -v ./...
|
||||
sudo go test -count=1 -v ./... -tags=integration
|
||||
|
||||
test:
|
||||
go test -count=1 -v ./...
|
||||
|
|
|
@ -37,7 +37,7 @@ func (r *roundRobin) DialContext(ctx context.Context, s *Subnet, network, addres
|
|||
|
||||
dd := r.d.dialer
|
||||
dd.LocalAddr = ii.LocalAddr
|
||||
return r.d.dialContext(&dd, ctx, "tcp", address)
|
||||
return r.d.dialContext(&dd, ctx, network, address)
|
||||
}
|
||||
return nil, errors.New("(*roundRobin).DialContext: no suitale node found")
|
||||
}
|
||||
|
@ -55,7 +55,7 @@ func (r *firstEnabled) DialContext(ctx context.Context, s *Subnet, network, addr
|
|||
|
||||
dd := r.d.dialer
|
||||
dd.LocalAddr = ii.LocalAddr
|
||||
return r.d.dialContext(&dd, ctx, "tcp", address)
|
||||
return r.d.dialContext(&dd, ctx, network, address)
|
||||
}
|
||||
return nil, errors.New("(*firstEnabled).DialContext: no suitale node found")
|
||||
}
|
||||
|
|
235
dialer.go
235
dialer.go
|
@ -8,6 +8,11 @@ import (
|
|||
"net/netip"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultFallbackDelay = 300 * time.Millisecond
|
||||
)
|
||||
|
||||
// Dialer contains the single most important method from the net.Dialer.
|
||||
|
@ -36,8 +41,12 @@ type dialer struct {
|
|||
restrict bool
|
||||
// Algorithm to which picks source address for each ip.
|
||||
balancer balancer
|
||||
// Hook used in tests, overrides `net.Dialer.DialContext()`
|
||||
testHookDialContext func(d *net.Dialer, ctx context.Context, network, address string) (net.Conn, error)
|
||||
// Overrides `net.Dialer.DialContext()` if specified.
|
||||
customDialContext func(d *net.Dialer, ctx context.Context, network, address string) (net.Conn, error)
|
||||
// Hostname resolver
|
||||
resolver net.Resolver
|
||||
// See Config.FallbackDelay description.
|
||||
fallbackDelay time.Duration
|
||||
}
|
||||
|
||||
// Subnet represents a single subnet, possibly routable from multiple interfaces.
|
||||
|
@ -65,16 +74,37 @@ type Config struct {
|
|||
Dialer net.Dialer
|
||||
// Balancer specifies algorithm used to pick source address.
|
||||
Balancer BalancerType
|
||||
// FallbackDelay specifies the length of time to wait before
|
||||
// spawning a RFC 6555 Fast Fallback connection. That is, this
|
||||
// is the amount of time to wait for IPv6 to succeed before
|
||||
// assuming that IPv6 is misconfigured and falling back to
|
||||
// IPv4.
|
||||
//
|
||||
// If zero, a default delay of 300ms is used.
|
||||
// A negative value disables Fast Fallback support.
|
||||
FallbackDelay time.Duration
|
||||
// InterfaceSource is custom `Interface`` source.
|
||||
// If not specified, default implementation is used (`net.Interfaces()``).
|
||||
InterfaceSource func() ([]Interface, error)
|
||||
// DialContext is custom DialContext function.
|
||||
// If not specified, default implemenattion is used (`d.DialContext(ctx, network, address)`).
|
||||
DialContext func(d *net.Dialer, ctx context.Context, network, address string) (net.Conn, error)
|
||||
}
|
||||
|
||||
// NewDialer ...
|
||||
func NewDialer(c Config) (Multidialer, error) {
|
||||
ifaces, err := net.Interfaces()
|
||||
var ifaces []Interface
|
||||
var err error
|
||||
if c.InterfaceSource != nil {
|
||||
ifaces, err = c.InterfaceSource()
|
||||
} else {
|
||||
ifaces, err = systemInterfaces()
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sort.Slice(ifaces, func(i, j int) bool {
|
||||
return ifaces[i].Name < ifaces[j].Name
|
||||
return ifaces[i].Name() < ifaces[j].Name()
|
||||
})
|
||||
|
||||
var sources []iface
|
||||
|
@ -105,16 +135,25 @@ func NewDialer(c Config) (Multidialer, error) {
|
|||
}
|
||||
|
||||
d.restrict = c.Restrict
|
||||
d.resolver.Dial = d.dialContextIP
|
||||
d.fallbackDelay = c.FallbackDelay
|
||||
if d.fallbackDelay == 0 {
|
||||
d.fallbackDelay = defaultFallbackDelay
|
||||
}
|
||||
d.dialer = c.Dialer
|
||||
if c.DialContext != nil {
|
||||
d.customDialContext = c.DialContext
|
||||
}
|
||||
|
||||
return &d, nil
|
||||
}
|
||||
|
||||
type iface struct {
|
||||
info net.Interface
|
||||
name string
|
||||
addrs []netip.Prefix
|
||||
}
|
||||
|
||||
func processIface(info net.Interface) (iface, error) {
|
||||
func processIface(info Interface) (iface, error) {
|
||||
ips, err := info.Addrs()
|
||||
if err != nil {
|
||||
return iface{}, err
|
||||
|
@ -129,7 +168,7 @@ func processIface(info net.Interface) (iface, error) {
|
|||
|
||||
addrs = append(addrs, p)
|
||||
}
|
||||
return iface{info: info, addrs: addrs}, nil
|
||||
return iface{name: info.Name(), addrs: addrs}, nil
|
||||
}
|
||||
|
||||
func processSubnet(subnet string, sources []iface) (Subnet, error) {
|
||||
|
@ -144,7 +183,7 @@ func processSubnet(subnet string, sources []iface) (Subnet, error) {
|
|||
src := source.addrs[i].Addr()
|
||||
if s.Contains(src) {
|
||||
ifs = append(ifs, Source{
|
||||
Name: source.info.Name,
|
||||
Name: source.name,
|
||||
LocalAddr: &net.TCPAddr{IP: net.IP(src.AsSlice())},
|
||||
})
|
||||
}
|
||||
|
@ -168,10 +207,164 @@ func processSubnet(subnet string, sources []iface) (Subnet, error) {
|
|||
// Hostnames for address are currently not supported.
|
||||
func (d *dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
addr, err := netip.ParseAddrPort(address)
|
||||
if err != nil { //try resolve as hostname
|
||||
return d.dialContextHostname(ctx, network, address)
|
||||
}
|
||||
return d.dialAddr(ctx, network, address, addr)
|
||||
}
|
||||
|
||||
func (d *dialer) dialContextIP(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
addr, err := netip.ParseAddrPort(address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return d.dialAddr(ctx, network, address, addr)
|
||||
}
|
||||
|
||||
func (d *dialer) dialContextHostname(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
// https://github.com/golang/go/blob/release-branch.go1.21/src/net/dial.go#L488
|
||||
addrPorts, err := d.resolveHostIPs(ctx, address, network)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var primaries, fallbacks []netip.AddrPort
|
||||
if d.fallbackDelay >= 0 && network == "tcp" {
|
||||
primaries, fallbacks = splitByType(addrPorts)
|
||||
} else {
|
||||
primaries = addrPorts
|
||||
}
|
||||
|
||||
return d.dialParallel(ctx, network, primaries, fallbacks)
|
||||
}
|
||||
|
||||
func (d *dialer) dialSerial(ctx context.Context, network string, addrs []netip.AddrPort) (net.Conn, error) {
|
||||
var firstErr error
|
||||
for _, addr := range addrs {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
conn, err := d.dialAddr(ctx, network, addr.String(), addr)
|
||||
if err == nil {
|
||||
return conn, nil
|
||||
}
|
||||
if firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("failed to connect any resolved address: %w", firstErr)
|
||||
}
|
||||
|
||||
func (d *dialer) dialParallel(ctx context.Context, network string, primaries, fallbacks []netip.AddrPort) (net.Conn, error) {
|
||||
if len(fallbacks) == 0 {
|
||||
return d.dialSerial(ctx, network, primaries)
|
||||
}
|
||||
|
||||
returned := make(chan struct{})
|
||||
defer close(returned)
|
||||
|
||||
type dialResult struct {
|
||||
net.Conn
|
||||
error
|
||||
primary bool
|
||||
done bool
|
||||
}
|
||||
results := make(chan dialResult)
|
||||
|
||||
dialerFunc := func(ctx context.Context, primary bool) {
|
||||
addrs := primaries
|
||||
if !primary {
|
||||
addrs = fallbacks
|
||||
}
|
||||
c, err := d.dialSerial(ctx, network, addrs)
|
||||
select {
|
||||
case results <- dialResult{Conn: c, error: err, primary: primary, done: true}:
|
||||
case <-returned:
|
||||
if c != nil {
|
||||
c.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var primary, fallback dialResult
|
||||
|
||||
primaryCtx, primaryCancel := context.WithCancel(ctx)
|
||||
defer primaryCancel()
|
||||
go dialerFunc(primaryCtx, true)
|
||||
|
||||
fallbackTimer := time.NewTimer(d.fallbackDelay)
|
||||
defer fallbackTimer.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-fallbackTimer.C:
|
||||
fallbackCtx, fallbackCancel := context.WithCancel(ctx)
|
||||
defer fallbackCancel()
|
||||
|
||||
go dialerFunc(fallbackCtx, false)
|
||||
|
||||
case res := <-results:
|
||||
if res.error == nil {
|
||||
return res.Conn, nil
|
||||
}
|
||||
if res.primary {
|
||||
primary = res
|
||||
} else {
|
||||
fallback = res
|
||||
}
|
||||
if primary.done && fallback.done {
|
||||
return nil, primary.error
|
||||
}
|
||||
if res.primary && fallbackTimer.Stop() {
|
||||
// If we were able to stop the timer, that means it
|
||||
// was running (hadn't yet started the fallback), but
|
||||
// we just got an error on the primary path, so start
|
||||
// the fallback immediately (in 0 nanoseconds).
|
||||
fallbackTimer.Reset(0)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (d *dialer) resolveHostIPs(ctx context.Context, address string, network string) ([]netip.AddrPort, error) {
|
||||
host, port, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid address format: %w", err)
|
||||
}
|
||||
|
||||
portnum, err := d.resolver.LookupPort(ctx, network, port)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid port format: %w", err)
|
||||
}
|
||||
|
||||
ips, err := d.resolver.LookupHost(ctx, host)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to resolve host address: %w", err)
|
||||
}
|
||||
|
||||
if len(ips) == 0 {
|
||||
return nil, fmt.Errorf("failed to resolve address for [%s]%s", network, address)
|
||||
}
|
||||
|
||||
var ipAddrs []netip.Addr
|
||||
for _, ip := range ips {
|
||||
ipAddr, err := netip.ParseAddr(ip)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse ip address '%s': %w", ip, err)
|
||||
}
|
||||
ipAddrs = append(ipAddrs, ipAddr)
|
||||
}
|
||||
|
||||
var addrPorts []netip.AddrPort
|
||||
for _, ipAddr := range ipAddrs {
|
||||
addrPorts = append(addrPorts, netip.AddrPortFrom(ipAddr, uint16(portnum)))
|
||||
}
|
||||
return addrPorts, nil
|
||||
}
|
||||
|
||||
func (d *dialer) dialAddr(ctx context.Context, network, address string, addr netip.AddrPort) (net.Conn, error) {
|
||||
d.mtx.RLock()
|
||||
defer d.mtx.RUnlock()
|
||||
|
||||
|
@ -184,14 +377,14 @@ func (d *dialer) DialContext(ctx context.Context, network, address string) (net.
|
|||
if d.restrict {
|
||||
return nil, fmt.Errorf("no suitable interface for: [%s]%s", network, address)
|
||||
}
|
||||
return d.dialer.DialContext(ctx, network, address)
|
||||
return d.dialContext(&d.dialer, ctx, network, address)
|
||||
}
|
||||
|
||||
func (d *dialer) dialContext(nd *net.Dialer, ctx context.Context, network, address string) (net.Conn, error) {
|
||||
if h := d.testHookDialContext; h != nil {
|
||||
return h(nd, ctx, "tcp", address)
|
||||
if h := d.customDialContext; h != nil {
|
||||
return h(nd, ctx, network, address)
|
||||
}
|
||||
return nd.DialContext(ctx, "tcp", address)
|
||||
return nd.DialContext(ctx, network, address)
|
||||
}
|
||||
|
||||
// UpdateInterface implements the Multidialer interface.
|
||||
|
@ -217,3 +410,21 @@ func (d *dialer) UpdateInterface(iface string, addr netip.Addr, up bool) error {
|
|||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// splitByType divides an address list into two categories:
|
||||
// the first address, and any with same type, are returned as
|
||||
// primaries, while addresses with the opposite type are returned
|
||||
// as fallbacks.
|
||||
func splitByType(addrs []netip.AddrPort) (primaries []netip.AddrPort, fallbacks []netip.AddrPort) {
|
||||
var primaryLabel bool
|
||||
for i, addr := range addrs {
|
||||
label := addr.Addr().Is4()
|
||||
if i == 0 || label == primaryLabel {
|
||||
primaryLabel = label
|
||||
primaries = append(primaries, addr)
|
||||
} else {
|
||||
fallbacks = append(fallbacks, addr)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
308
dialer_hostname_test.go
Normal file
308
dialer_hostname_test.go
Normal file
|
@ -0,0 +1,308 @@
|
|||
package multinet
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/net/dns/dnsmessage"
|
||||
)
|
||||
|
||||
func TestHostnameResolveIPv4(t *testing.T) {
|
||||
resolvedAddr := "10.11.12.180:8080"
|
||||
resolved := false
|
||||
d, err := NewDialer(Config{
|
||||
Subnets: []string{"10.11.12.0/24"},
|
||||
InterfaceSource: testInterfacesV4,
|
||||
DialContext: func(d *net.Dialer, ctx context.Context, network, address string) (net.Conn, error) {
|
||||
if resolvedAddr == address {
|
||||
resolved = true
|
||||
return &testConnStub{}, nil
|
||||
}
|
||||
if network == "udp" {
|
||||
return &testDnsConn{
|
||||
wantName: "s01.storage.devnev.",
|
||||
ipv4: []byte{10, 11, 12, 180},
|
||||
}, nil
|
||||
}
|
||||
panic("unexpected call")
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
conn, err := d.DialContext(context.Background(), "tcp", "s01.storage.devnev:8080")
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, conn.Close())
|
||||
require.True(t, resolved)
|
||||
}
|
||||
|
||||
func TestHostnameResolveIPv6(t *testing.T) {
|
||||
resolvedAddr := "[2001:db8:85a3:8d3:1319:8a2e:370:8195]:8080"
|
||||
ipv6 := net.ParseIP("2001:db8:85a3:8d3:1319:8a2e:370:8195")
|
||||
resolved := false
|
||||
d, err := NewDialer(Config{
|
||||
Subnets: []string{"2001:db8:85a3:8d3::/64"},
|
||||
InterfaceSource: testInterfacesV6,
|
||||
DialContext: func(d *net.Dialer, ctx context.Context, network, address string) (net.Conn, error) {
|
||||
if resolvedAddr == address {
|
||||
resolved = true
|
||||
return &testConnStub{}, nil
|
||||
}
|
||||
if network == "udp" {
|
||||
return &testDnsConn{
|
||||
wantName: "s01.storage.devnev.",
|
||||
ipv6: ipv6,
|
||||
}, nil
|
||||
}
|
||||
panic("unexpected call")
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
conn, err := d.DialContext(context.Background(), "tcp", "s01.storage.devnev:8080")
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, conn.Close())
|
||||
require.True(t, resolved)
|
||||
}
|
||||
|
||||
func testInterfacesV4() ([]Interface, error) {
|
||||
return []Interface{
|
||||
&testInterface{
|
||||
name: "data1",
|
||||
addrs: []net.Addr{
|
||||
&testAddr{
|
||||
network: "tcp",
|
||||
str: "10.11.12.101/24",
|
||||
},
|
||||
},
|
||||
},
|
||||
&testInterface{
|
||||
name: "data2",
|
||||
addrs: []net.Addr{
|
||||
&testAddr{
|
||||
network: "tcp",
|
||||
str: "10.11.12.102/24",
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func testInterfacesV6() ([]Interface, error) {
|
||||
return []Interface{
|
||||
&testInterface{
|
||||
name: "data1",
|
||||
addrs: []net.Addr{
|
||||
&testAddr{
|
||||
network: "tcp",
|
||||
str: "2001:db8:85a3:8d3:1319:8a2e:370:7348/64",
|
||||
},
|
||||
},
|
||||
},
|
||||
&testInterface{
|
||||
name: "data2",
|
||||
addrs: []net.Addr{
|
||||
&testAddr{
|
||||
network: "tcp",
|
||||
str: "2001:db8:85a3:8d3:1319:8a2e:370:8192/64",
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
type testInterface struct {
|
||||
name string
|
||||
addrs []net.Addr
|
||||
}
|
||||
|
||||
func (i *testInterface) Name() string { return i.name }
|
||||
func (i *testInterface) Addrs() ([]net.Addr, error) { return i.addrs, nil }
|
||||
|
||||
type testAddr struct {
|
||||
network string
|
||||
str string
|
||||
}
|
||||
|
||||
func (a *testAddr) Network() string { return a.network }
|
||||
func (a *testAddr) String() string { return a.str }
|
||||
|
||||
type testDnsConn struct {
|
||||
wantName string
|
||||
ipv4 []byte
|
||||
ipv6 []byte
|
||||
reqID uint16
|
||||
requested bool
|
||||
secondRead bool
|
||||
questions []dnsmessage.Question
|
||||
}
|
||||
|
||||
// Close implements net.Conn.
|
||||
func (*testDnsConn) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// LocalAddr implements net.Conn.
|
||||
func (*testDnsConn) LocalAddr() net.Addr {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
// Read implements net.Conn.
|
||||
func (c *testDnsConn) Read(b []byte) (n int, err error) {
|
||||
if c.secondRead {
|
||||
data, err := c.data()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
copy(b, data)
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
data, err := c.data()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
l := len(data)
|
||||
b[0] = byte(l >> 8)
|
||||
b[1] = byte(l)
|
||||
c.secondRead = true
|
||||
return 2, nil
|
||||
}
|
||||
|
||||
func (c *testDnsConn) data() ([]byte, error) {
|
||||
if !c.requested {
|
||||
builder := dnsmessage.NewBuilder(nil, dnsmessage.Header{RCode: dnsmessage.RCodeNameError})
|
||||
res, err := builder.Finish()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
msg := dnsmessage.Message{
|
||||
Header: dnsmessage.Header{Response: true, Authoritative: true, RCode: dnsmessage.RCodeSuccess, ID: c.reqID},
|
||||
Questions: c.questions,
|
||||
Answers: []dnsmessage.Resource{},
|
||||
}
|
||||
c.appendIPv4(&msg)
|
||||
c.appendIPV6(&msg)
|
||||
return msg.Pack()
|
||||
}
|
||||
|
||||
func (c *testDnsConn) appendIPv4(msg *dnsmessage.Message) {
|
||||
if c.ipv4 != nil {
|
||||
msg.Answers = append(msg.Answers, dnsmessage.Resource{
|
||||
Header: dnsmessage.ResourceHeader{
|
||||
Name: dnsmessage.MustNewName(c.wantName),
|
||||
Type: dnsmessage.TypeA,
|
||||
Class: dnsmessage.ClassINET,
|
||||
},
|
||||
Body: &dnsmessage.AResource{A: [4]byte(c.ipv4)},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (c *testDnsConn) appendIPV6(msg *dnsmessage.Message) {
|
||||
if c.ipv6 != nil {
|
||||
msg.Answers = append(msg.Answers, dnsmessage.Resource{
|
||||
Header: dnsmessage.ResourceHeader{
|
||||
Name: dnsmessage.MustNewName(c.wantName),
|
||||
Type: dnsmessage.TypeAAAA,
|
||||
Class: dnsmessage.ClassINET,
|
||||
},
|
||||
Body: &dnsmessage.AAAAResource{AAAA: [16]byte(c.ipv6)},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// RemoteAddr implements net.Conn.
|
||||
func (*testDnsConn) RemoteAddr() net.Addr {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
// SetDeadline implements net.Conn.
|
||||
func (*testDnsConn) SetDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetReadDeadline implements net.Conn.
|
||||
func (*testDnsConn) SetReadDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetWriteDeadline implements net.Conn.
|
||||
func (*testDnsConn) SetWriteDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Write implements net.Conn.
|
||||
func (c *testDnsConn) Write(b []byte) (n int, err error) {
|
||||
var p dnsmessage.Parser
|
||||
var h dnsmessage.Header
|
||||
if h, err = p.Start(b[2:]); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
c.questions, err = p.AllQuestions()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
for _, q := range c.questions {
|
||||
qStr := q.Name.String()
|
||||
if qStr != c.wantName {
|
||||
continue
|
||||
}
|
||||
|
||||
c.requested = true
|
||||
c.reqID = h.ID
|
||||
if err := p.SkipAllQuestions(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
type testConnStub struct{}
|
||||
|
||||
// Close implements net.Conn.
|
||||
func (*testConnStub) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// LocalAddr implements net.Conn.
|
||||
func (*testConnStub) LocalAddr() net.Addr {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
// Read implements net.Conn.
|
||||
func (*testConnStub) Read(b []byte) (n int, err error) {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
// RemoteAddr implements net.Conn.
|
||||
func (*testConnStub) RemoteAddr() net.Addr {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
// SetDeadline implements net.Conn.
|
||||
func (*testConnStub) SetDeadline(t time.Time) error {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
// SetReadDeadline implements net.Conn.
|
||||
func (*testConnStub) SetReadDeadline(t time.Time) error {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
// SetWriteDeadline implements net.Conn.
|
||||
func (*testConnStub) SetWriteDeadline(t time.Time) error {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
// Write implements net.Conn.
|
||||
func (*testConnStub) Write(b []byte) (n int, err error) {
|
||||
panic("unimplemented")
|
||||
}
|
|
@ -1,3 +1,5 @@
|
|||
//go:build integration
|
||||
|
||||
package multinet
|
||||
|
||||
import (
|
3
go.mod
3
go.mod
|
@ -6,7 +6,8 @@ require (
|
|||
github.com/stretchr/testify v1.8.4
|
||||
github.com/vishvananda/netlink v1.1.0
|
||||
github.com/vishvananda/netns v0.0.4
|
||||
golang.org/x/sys v0.2.0
|
||||
golang.org/x/net v0.17.0
|
||||
golang.org/x/sys v0.13.0
|
||||
)
|
||||
|
||||
require (
|
||||
|
|
BIN
go.sum
BIN
go.sum
Binary file not shown.
|
@ -1,3 +1,5 @@
|
|||
//go:build integration
|
||||
|
||||
package multinet
|
||||
|
||||
import (
|
||||
|
@ -21,8 +23,13 @@ func Test_NetlinkWatcher(t *testing.T) {
|
|||
addr1 := &net.TCPAddr{IP: net.IP{1, 2, 30, 11}}
|
||||
addr2 := &net.TCPAddr{IP: net.IP{1, 2, 30, 12}}
|
||||
|
||||
result := make(chan net.Addr, 1)
|
||||
d, err := NewDialer(Config{
|
||||
Subnets: []string{"1.2.30.0/23"},
|
||||
DialContext: func(d *net.Dialer, _ context.Context, _, _ string) (net.Conn, error) {
|
||||
result <- d.LocalAddr
|
||||
return nil, nil
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
|
@ -30,12 +37,6 @@ func Test_NetlinkWatcher(t *testing.T) {
|
|||
require.NoError(t, w.Start())
|
||||
t.Cleanup(w.Stop)
|
||||
|
||||
result := make(chan net.Addr, 1)
|
||||
d.(*dialer).testHookDialContext = func(d *net.Dialer, _ context.Context, _, _ string) (net.Conn, error) {
|
||||
result <- d.LocalAddr
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
checkDialAddr(t, d, result, addr1)
|
||||
checkDialAddr(t, d, result, addr1)
|
||||
|
||||
|
@ -62,8 +63,13 @@ func Test_NetlinkWatcher(t *testing.T) {
|
|||
addr1 := &net.TCPAddr{IP: net.IP{1, 2, 30, 11}}
|
||||
addr2 := &net.TCPAddr{IP: net.IP{1, 2, 30, 12}}
|
||||
|
||||
result := make(chan net.Addr, 1)
|
||||
d, err := NewDialer(Config{
|
||||
Subnets: []string{"1.2.30.0/23"},
|
||||
DialContext: func(d *net.Dialer, _ context.Context, _, _ string) (net.Conn, error) {
|
||||
result <- d.LocalAddr
|
||||
return nil, nil
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
|
@ -71,12 +77,6 @@ func Test_NetlinkWatcher(t *testing.T) {
|
|||
require.NoError(t, w.Start())
|
||||
t.Cleanup(w.Stop)
|
||||
|
||||
result := make(chan net.Addr, 1)
|
||||
d.(*dialer).testHookDialContext = func(d *net.Dialer, _ context.Context, _, _ string) (net.Conn, error) {
|
||||
result <- d.LocalAddr
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
checkDialAddr(t, d, result, addr1)
|
||||
checkDialAddr(t, d, result, addr1)
|
||||
|
||||
|
@ -106,9 +106,14 @@ func Test_NetlinkWatcher(t *testing.T) {
|
|||
addr1 := &net.TCPAddr{IP: net.IP{1, 2, 30, 11}}
|
||||
addr2 := &net.TCPAddr{IP: net.IP{1, 2, 30, 12}}
|
||||
|
||||
result := make(chan net.Addr, 1)
|
||||
d, err := NewDialer(Config{
|
||||
Subnets: []string{"1.2.30.0/23"},
|
||||
Balancer: BalancerTypeRoundRobin,
|
||||
DialContext: func(d *net.Dialer, _ context.Context, _, _ string) (net.Conn, error) {
|
||||
result <- d.LocalAddr
|
||||
return nil, nil
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
|
@ -116,12 +121,6 @@ func Test_NetlinkWatcher(t *testing.T) {
|
|||
require.NoError(t, w.Start())
|
||||
t.Cleanup(w.Stop)
|
||||
|
||||
result := make(chan net.Addr, 1)
|
||||
d.(*dialer).testHookDialContext = func(d *net.Dialer, _ context.Context, _, _ string) (net.Conn, error) {
|
||||
result <- d.LocalAddr
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
checkDialAddr(t, d, result, addr2)
|
||||
checkDialAddr(t, d, result, addr1)
|
||||
checkDialAddr(t, d, result, addr2)
|
33
interface.go
Normal file
33
interface.go
Normal file
|
@ -0,0 +1,33 @@
|
|||
package multinet
|
||||
|
||||
import "net"
|
||||
|
||||
// Interface provides information about net.Interface.
|
||||
type Interface interface {
|
||||
Name() string
|
||||
Addrs() ([]net.Addr, error)
|
||||
}
|
||||
|
||||
type netInterface struct {
|
||||
iface net.Interface
|
||||
}
|
||||
|
||||
func (i *netInterface) Name() string {
|
||||
return i.iface.Name
|
||||
}
|
||||
|
||||
func (i *netInterface) Addrs() ([]net.Addr, error) {
|
||||
return i.iface.Addrs()
|
||||
}
|
||||
|
||||
func systemInterfaces() ([]Interface, error) {
|
||||
ifaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var result []Interface
|
||||
for _, iface := range ifaces {
|
||||
result = append(result, &netInterface{iface: iface})
|
||||
}
|
||||
return result, nil
|
||||
}
|
Loading…
Reference in a new issue
I don't know whether it was your idea or this is copy-pasted but seems really good