Support hostnames #6
9 changed files with 596 additions and 37 deletions
7
Makefile
7
Makefile
|
@ -1,3 +1,6 @@
|
|||
test:
|
||||
integration-test:
|
||||
# TODO figure out needed capabilities
|
||||
sudo go test -count=1 -v ./...
|
||||
sudo go test -count=1 -v ./... -tags=integration
|
||||
|
||||
test:
|
||||
go test -count=1 -v ./...
|
||||
|
|
|
@ -37,7 +37,7 @@ func (r *roundRobin) DialContext(ctx context.Context, s *Subnet, network, addres
|
|||
|
||||
dd := r.d.dialer
|
||||
dd.LocalAddr = ii.LocalAddr
|
||||
return r.d.dialContext(&dd, ctx, "tcp", address)
|
||||
return r.d.dialContext(&dd, ctx, network, address)
|
||||
}
|
||||
return nil, errors.New("(*roundRobin).DialContext: no suitale node found")
|
||||
}
|
||||
|
@ -55,7 +55,7 @@ func (r *firstEnabled) DialContext(ctx context.Context, s *Subnet, network, addr
|
|||
|
||||
dd := r.d.dialer
|
||||
dd.LocalAddr = ii.LocalAddr
|
||||
return r.d.dialContext(&dd, ctx, "tcp", address)
|
||||
return r.d.dialContext(&dd, ctx, network, address)
|
||||
}
|
||||
return nil, errors.New("(*firstEnabled).DialContext: no suitale node found")
|
||||
}
|
||||
|
|
235
dialer.go
235
dialer.go
|
@ -8,6 +8,11 @@ import (
|
|||
"net/netip"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultFallbackDelay = 300 * time.Millisecond
|
||||
)
|
||||
|
||||
// Dialer contains the single most important method from the net.Dialer.
|
||||
|
@ -36,8 +41,12 @@ type dialer struct {
|
|||
restrict bool
|
||||
// Algorithm to which picks source address for each ip.
|
||||
balancer balancer
|
||||
// Hook used in tests, overrides `net.Dialer.DialContext()`
|
||||
testHookDialContext func(d *net.Dialer, ctx context.Context, network, address string) (net.Conn, error)
|
||||
// Overrides `net.Dialer.DialContext()` if specified.
|
||||
customDialContext func(d *net.Dialer, ctx context.Context, network, address string) (net.Conn, error)
|
||||
// Hostname resolver
|
||||
resolver net.Resolver
|
||||
// See Config.FallbackDelay description.
|
||||
fallbackDelay time.Duration
|
||||
}
|
||||
|
||||
// Subnet represents a single subnet, possibly routable from multiple interfaces.
|
||||
|
@ -65,16 +74,37 @@ type Config struct {
|
|||
Dialer net.Dialer
|
||||
// Balancer specifies algorithm used to pick source address.
|
||||
Balancer BalancerType
|
||||
// FallbackDelay specifies the length of time to wait before
|
||||
// spawning a RFC 6555 Fast Fallback connection. That is, this
|
||||
// is the amount of time to wait for IPv6 to succeed before
|
||||
// assuming that IPv6 is misconfigured and falling back to
|
||||
// IPv4.
|
||||
//
|
||||
// If zero, a default delay of 300ms is used.
|
||||
// A negative value disables Fast Fallback support.
|
||||
FallbackDelay time.Duration
|
||||
// InterfaceSource is custom `Interface`` source.
|
||||
// If not specified, default implementation is used (`net.Interfaces()``).
|
||||
InterfaceSource func() ([]Interface, error)
|
||||
fyrchik marked this conversation as resolved
Outdated
|
||||
// DialContext is custom DialContext function.
|
||||
// If not specified, default implemenattion is used (`d.DialContext(ctx, network, address)`).
|
||||
DialContext func(d *net.Dialer, ctx context.Context, network, address string) (net.Conn, error)
|
||||
}
|
||||
|
||||
// NewDialer ...
|
||||
func NewDialer(c Config) (Multidialer, error) {
|
||||
ifaces, err := net.Interfaces()
|
||||
var ifaces []Interface
|
||||
var err error
|
||||
if c.InterfaceSource != nil {
|
||||
ifaces, err = c.InterfaceSource()
|
||||
} else {
|
||||
ifaces, err = systemInterfaces()
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sort.Slice(ifaces, func(i, j int) bool {
|
||||
return ifaces[i].Name < ifaces[j].Name
|
||||
return ifaces[i].Name() < ifaces[j].Name()
|
||||
})
|
||||
|
||||
var sources []iface
|
||||
|
@ -105,16 +135,25 @@ func NewDialer(c Config) (Multidialer, error) {
|
|||
}
|
||||
|
||||
d.restrict = c.Restrict
|
||||
d.resolver.Dial = d.dialContextIP
|
||||
d.fallbackDelay = c.FallbackDelay
|
||||
if d.fallbackDelay == 0 {
|
||||
d.fallbackDelay = defaultFallbackDelay
|
||||
}
|
||||
d.dialer = c.Dialer
|
||||
if c.DialContext != nil {
|
||||
d.customDialContext = c.DialContext
|
||||
}
|
||||
|
||||
return &d, nil
|
||||
}
|
||||
|
||||
type iface struct {
|
||||
info net.Interface
|
||||
name string
|
||||
addrs []netip.Prefix
|
||||
}
|
||||
|
||||
func processIface(info net.Interface) (iface, error) {
|
||||
func processIface(info Interface) (iface, error) {
|
||||
ips, err := info.Addrs()
|
||||
if err != nil {
|
||||
return iface{}, err
|
||||
|
@ -129,7 +168,7 @@ func processIface(info net.Interface) (iface, error) {
|
|||
|
||||
addrs = append(addrs, p)
|
||||
}
|
||||
return iface{info: info, addrs: addrs}, nil
|
||||
return iface{name: info.Name(), addrs: addrs}, nil
|
||||
}
|
||||
|
||||
func processSubnet(subnet string, sources []iface) (Subnet, error) {
|
||||
|
@ -144,7 +183,7 @@ func processSubnet(subnet string, sources []iface) (Subnet, error) {
|
|||
src := source.addrs[i].Addr()
|
||||
if s.Contains(src) {
|
||||
ifs = append(ifs, Source{
|
||||
Name: source.info.Name,
|
||||
Name: source.name,
|
||||
LocalAddr: &net.TCPAddr{IP: net.IP(src.AsSlice())},
|
||||
})
|
||||
}
|
||||
|
@ -168,10 +207,164 @@ func processSubnet(subnet string, sources []iface) (Subnet, error) {
|
|||
// Hostnames for address are currently not supported.
|
||||
func (d *dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
addr, err := netip.ParseAddrPort(address)
|
||||
if err != nil { //try resolve as hostname
|
||||
return d.dialContextHostname(ctx, network, address)
|
||||
}
|
||||
return d.dialAddr(ctx, network, address, addr)
|
||||
}
|
||||
|
||||
func (d *dialer) dialContextIP(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
addr, err := netip.ParseAddrPort(address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return d.dialAddr(ctx, network, address, addr)
|
||||
}
|
||||
|
||||
func (d *dialer) dialContextHostname(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
// https://github.com/golang/go/blob/release-branch.go1.21/src/net/dial.go#L488
|
||||
addrPorts, err := d.resolveHostIPs(ctx, address, network)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var primaries, fallbacks []netip.AddrPort
|
||||
fyrchik
commented
Only Only `tcp` or `tcp4`/`tcp6` too?
dstepanov-yadro
commented
In stdlib only In stdlib only `tcp` defined: https://github.com/golang/go/blob/master/src/net/dial.go#L502
fyrchik
commented
I was looking at this function https://github.com/golang/go/blob/master/src/net/dial.go#L229, which is called inside I was looking at this function https://github.com/golang/go/blob/master/src/net/dial.go#L229, which is called inside `resolveAddrList`
|
||||
if d.fallbackDelay >= 0 && network == "tcp" {
|
||||
primaries, fallbacks = splitByType(addrPorts)
|
||||
} else {
|
||||
primaries = addrPorts
|
||||
}
|
||||
|
||||
return d.dialParallel(ctx, network, primaries, fallbacks)
|
||||
}
|
||||
|
||||
fyrchik marked this conversation as resolved
Outdated
fyrchik
commented
If this is taken from stdlib, could we supply references in comments? If this is taken from stdlib, could we supply references in comments?
dstepanov-yadro
commented
Done Done
|
||||
func (d *dialer) dialSerial(ctx context.Context, network string, addrs []netip.AddrPort) (net.Conn, error) {
|
||||
var firstErr error
|
||||
for _, addr := range addrs {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
conn, err := d.dialAddr(ctx, network, addr.String(), addr)
|
||||
if err == nil {
|
||||
return conn, nil
|
||||
}
|
||||
if firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("failed to connect any resolved address: %w", firstErr)
|
||||
}
|
||||
|
||||
func (d *dialer) dialParallel(ctx context.Context, network string, primaries, fallbacks []netip.AddrPort) (net.Conn, error) {
|
||||
if len(fallbacks) == 0 {
|
||||
return d.dialSerial(ctx, network, primaries)
|
||||
}
|
||||
|
||||
returned := make(chan struct{})
|
||||
defer close(returned)
|
||||
|
||||
type dialResult struct {
|
||||
net.Conn
|
||||
error
|
||||
primary bool
|
||||
done bool
|
||||
}
|
||||
results := make(chan dialResult)
|
||||
|
||||
dialerFunc := func(ctx context.Context, primary bool) {
|
||||
addrs := primaries
|
||||
if !primary {
|
||||
addrs = fallbacks
|
||||
}
|
||||
c, err := d.dialSerial(ctx, network, addrs)
|
||||
select {
|
||||
case results <- dialResult{Conn: c, error: err, primary: primary, done: true}:
|
||||
case <-returned:
|
||||
if c != nil {
|
||||
c.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var primary, fallback dialResult
|
||||
|
||||
primaryCtx, primaryCancel := context.WithCancel(ctx)
|
||||
defer primaryCancel()
|
||||
go dialerFunc(primaryCtx, true)
|
||||
|
||||
fallbackTimer := time.NewTimer(d.fallbackDelay)
|
||||
defer fallbackTimer.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-fallbackTimer.C:
|
||||
fallbackCtx, fallbackCancel := context.WithCancel(ctx)
|
||||
defer fallbackCancel()
|
||||
aarifullin
commented
I don't know whether it was your idea or this is copy-pasted but seems really good I don't know whether it was your idea or this is copy-pasted but seems really good
|
||||
go dialerFunc(fallbackCtx, false)
|
||||
|
||||
case res := <-results:
|
||||
if res.error == nil {
|
||||
return res.Conn, nil
|
||||
}
|
||||
if res.primary {
|
||||
primary = res
|
||||
} else {
|
||||
fallback = res
|
||||
}
|
||||
if primary.done && fallback.done {
|
||||
return nil, primary.error
|
||||
}
|
||||
if res.primary && fallbackTimer.Stop() {
|
||||
// If we were able to stop the timer, that means it
|
||||
// was running (hadn't yet started the fallback), but
|
||||
// we just got an error on the primary path, so start
|
||||
// the fallback immediately (in 0 nanoseconds).
|
||||
fallbackTimer.Reset(0)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (d *dialer) resolveHostIPs(ctx context.Context, address string, network string) ([]netip.AddrPort, error) {
|
||||
host, port, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid address format: %w", err)
|
||||
}
|
||||
|
||||
portnum, err := d.resolver.LookupPort(ctx, network, port)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid port format: %w", err)
|
||||
}
|
||||
|
||||
ips, err := d.resolver.LookupHost(ctx, host)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to resolve host address: %w", err)
|
||||
}
|
||||
|
||||
if len(ips) == 0 {
|
||||
return nil, fmt.Errorf("failed to resolve address for [%s]%s", network, address)
|
||||
}
|
||||
|
||||
var ipAddrs []netip.Addr
|
||||
for _, ip := range ips {
|
||||
ipAddr, err := netip.ParseAddr(ip)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse ip address '%s': %w", ip, err)
|
||||
}
|
||||
ipAddrs = append(ipAddrs, ipAddr)
|
||||
}
|
||||
|
||||
var addrPorts []netip.AddrPort
|
||||
for _, ipAddr := range ipAddrs {
|
||||
addrPorts = append(addrPorts, netip.AddrPortFrom(ipAddr, uint16(portnum)))
|
||||
}
|
||||
return addrPorts, nil
|
||||
}
|
||||
|
||||
func (d *dialer) dialAddr(ctx context.Context, network, address string, addr netip.AddrPort) (net.Conn, error) {
|
||||
d.mtx.RLock()
|
||||
defer d.mtx.RUnlock()
|
||||
|
||||
|
@ -184,14 +377,14 @@ func (d *dialer) DialContext(ctx context.Context, network, address string) (net.
|
|||
if d.restrict {
|
||||
return nil, fmt.Errorf("no suitable interface for: [%s]%s", network, address)
|
||||
}
|
||||
return d.dialer.DialContext(ctx, network, address)
|
||||
return d.dialContext(&d.dialer, ctx, network, address)
|
||||
}
|
||||
|
||||
func (d *dialer) dialContext(nd *net.Dialer, ctx context.Context, network, address string) (net.Conn, error) {
|
||||
fyrchik marked this conversation as resolved
Outdated
fyrchik
commented
What was wrong with What was wrong with `testHook`? It is in spirit of the related stdlib pieces https://github.com/golang/go/blob/b5f87b5407916c4049a3158cc944cebfd7a883a9/src/net/dial.go#L423
dstepanov-yadro
commented
Nothing wrong. But I guess it could be used not only for tests. For example tracing, logging, metrics. Nothing wrong. But I guess it could be used not only for tests. For example tracing, logging, metrics.
|
||||
if h := d.testHookDialContext; h != nil {
|
||||
return h(nd, ctx, "tcp", address)
|
||||
if h := d.customDialContext; h != nil {
|
||||
return h(nd, ctx, network, address)
|
||||
}
|
||||
return nd.DialContext(ctx, "tcp", address)
|
||||
return nd.DialContext(ctx, network, address)
|
||||
}
|
||||
|
||||
// UpdateInterface implements the Multidialer interface.
|
||||
|
@ -217,3 +410,21 @@ func (d *dialer) UpdateInterface(iface string, addr netip.Addr, up bool) error {
|
|||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// splitByType divides an address list into two categories:
|
||||
// the first address, and any with same type, are returned as
|
||||
// primaries, while addresses with the opposite type are returned
|
||||
// as fallbacks.
|
||||
func splitByType(addrs []netip.AddrPort) (primaries []netip.AddrPort, fallbacks []netip.AddrPort) {
|
||||
var primaryLabel bool
|
||||
for i, addr := range addrs {
|
||||
label := addr.Addr().Is4()
|
||||
if i == 0 || label == primaryLabel {
|
||||
primaryLabel = label
|
||||
primaries = append(primaries, addr)
|
||||
} else {
|
||||
fallbacks = append(fallbacks, addr)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
308
dialer_hostname_test.go
Normal file
308
dialer_hostname_test.go
Normal file
|
@ -0,0 +1,308 @@
|
|||
package multinet
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/net/dns/dnsmessage"
|
||||
)
|
||||
|
||||
func TestHostnameResolveIPv4(t *testing.T) {
|
||||
resolvedAddr := "10.11.12.180:8080"
|
||||
resolved := false
|
||||
d, err := NewDialer(Config{
|
||||
Subnets: []string{"10.11.12.0/24"},
|
||||
InterfaceSource: testInterfacesV4,
|
||||
DialContext: func(d *net.Dialer, ctx context.Context, network, address string) (net.Conn, error) {
|
||||
if resolvedAddr == address {
|
||||
resolved = true
|
||||
return &testConnStub{}, nil
|
||||
}
|
||||
if network == "udp" {
|
||||
return &testDnsConn{
|
||||
wantName: "s01.storage.devnev.",
|
||||
ipv4: []byte{10, 11, 12, 180},
|
||||
}, nil
|
||||
}
|
||||
panic("unexpected call")
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
conn, err := d.DialContext(context.Background(), "tcp", "s01.storage.devnev:8080")
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, conn.Close())
|
||||
require.True(t, resolved)
|
||||
}
|
||||
|
||||
func TestHostnameResolveIPv6(t *testing.T) {
|
||||
resolvedAddr := "[2001:db8:85a3:8d3:1319:8a2e:370:8195]:8080"
|
||||
ipv6 := net.ParseIP("2001:db8:85a3:8d3:1319:8a2e:370:8195")
|
||||
resolved := false
|
||||
d, err := NewDialer(Config{
|
||||
Subnets: []string{"2001:db8:85a3:8d3::/64"},
|
||||
InterfaceSource: testInterfacesV6,
|
||||
DialContext: func(d *net.Dialer, ctx context.Context, network, address string) (net.Conn, error) {
|
||||
if resolvedAddr == address {
|
||||
resolved = true
|
||||
return &testConnStub{}, nil
|
||||
}
|
||||
if network == "udp" {
|
||||
return &testDnsConn{
|
||||
wantName: "s01.storage.devnev.",
|
||||
ipv6: ipv6,
|
||||
}, nil
|
||||
}
|
||||
panic("unexpected call")
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
conn, err := d.DialContext(context.Background(), "tcp", "s01.storage.devnev:8080")
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, conn.Close())
|
||||
require.True(t, resolved)
|
||||
}
|
||||
|
||||
func testInterfacesV4() ([]Interface, error) {
|
||||
return []Interface{
|
||||
&testInterface{
|
||||
name: "data1",
|
||||
addrs: []net.Addr{
|
||||
&testAddr{
|
||||
network: "tcp",
|
||||
str: "10.11.12.101/24",
|
||||
},
|
||||
},
|
||||
},
|
||||
&testInterface{
|
||||
name: "data2",
|
||||
addrs: []net.Addr{
|
||||
&testAddr{
|
||||
network: "tcp",
|
||||
str: "10.11.12.102/24",
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func testInterfacesV6() ([]Interface, error) {
|
||||
return []Interface{
|
||||
&testInterface{
|
||||
name: "data1",
|
||||
addrs: []net.Addr{
|
||||
&testAddr{
|
||||
network: "tcp",
|
||||
str: "2001:db8:85a3:8d3:1319:8a2e:370:7348/64",
|
||||
},
|
||||
},
|
||||
},
|
||||
&testInterface{
|
||||
name: "data2",
|
||||
addrs: []net.Addr{
|
||||
&testAddr{
|
||||
network: "tcp",
|
||||
str: "2001:db8:85a3:8d3:1319:8a2e:370:8192/64",
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
type testInterface struct {
|
||||
name string
|
||||
addrs []net.Addr
|
||||
}
|
||||
|
||||
func (i *testInterface) Name() string { return i.name }
|
||||
func (i *testInterface) Addrs() ([]net.Addr, error) { return i.addrs, nil }
|
||||
|
||||
type testAddr struct {
|
||||
network string
|
||||
str string
|
||||
}
|
||||
|
||||
func (a *testAddr) Network() string { return a.network }
|
||||
func (a *testAddr) String() string { return a.str }
|
||||
|
||||
type testDnsConn struct {
|
||||
wantName string
|
||||
ipv4 []byte
|
||||
ipv6 []byte
|
||||
reqID uint16
|
||||
requested bool
|
||||
secondRead bool
|
||||
questions []dnsmessage.Question
|
||||
}
|
||||
|
||||
// Close implements net.Conn.
|
||||
func (*testDnsConn) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// LocalAddr implements net.Conn.
|
||||
func (*testDnsConn) LocalAddr() net.Addr {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
// Read implements net.Conn.
|
||||
func (c *testDnsConn) Read(b []byte) (n int, err error) {
|
||||
if c.secondRead {
|
||||
data, err := c.data()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
copy(b, data)
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
data, err := c.data()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
l := len(data)
|
||||
b[0] = byte(l >> 8)
|
||||
b[1] = byte(l)
|
||||
c.secondRead = true
|
||||
return 2, nil
|
||||
}
|
||||
|
||||
func (c *testDnsConn) data() ([]byte, error) {
|
||||
if !c.requested {
|
||||
builder := dnsmessage.NewBuilder(nil, dnsmessage.Header{RCode: dnsmessage.RCodeNameError})
|
||||
res, err := builder.Finish()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
msg := dnsmessage.Message{
|
||||
Header: dnsmessage.Header{Response: true, Authoritative: true, RCode: dnsmessage.RCodeSuccess, ID: c.reqID},
|
||||
Questions: c.questions,
|
||||
Answers: []dnsmessage.Resource{},
|
||||
}
|
||||
c.appendIPv4(&msg)
|
||||
c.appendIPV6(&msg)
|
||||
return msg.Pack()
|
||||
}
|
||||
|
||||
func (c *testDnsConn) appendIPv4(msg *dnsmessage.Message) {
|
||||
if c.ipv4 != nil {
|
||||
msg.Answers = append(msg.Answers, dnsmessage.Resource{
|
||||
Header: dnsmessage.ResourceHeader{
|
||||
Name: dnsmessage.MustNewName(c.wantName),
|
||||
Type: dnsmessage.TypeA,
|
||||
Class: dnsmessage.ClassINET,
|
||||
},
|
||||
Body: &dnsmessage.AResource{A: [4]byte(c.ipv4)},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (c *testDnsConn) appendIPV6(msg *dnsmessage.Message) {
|
||||
if c.ipv6 != nil {
|
||||
msg.Answers = append(msg.Answers, dnsmessage.Resource{
|
||||
Header: dnsmessage.ResourceHeader{
|
||||
Name: dnsmessage.MustNewName(c.wantName),
|
||||
Type: dnsmessage.TypeAAAA,
|
||||
Class: dnsmessage.ClassINET,
|
||||
},
|
||||
Body: &dnsmessage.AAAAResource{AAAA: [16]byte(c.ipv6)},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// RemoteAddr implements net.Conn.
|
||||
func (*testDnsConn) RemoteAddr() net.Addr {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
// SetDeadline implements net.Conn.
|
||||
func (*testDnsConn) SetDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetReadDeadline implements net.Conn.
|
||||
func (*testDnsConn) SetReadDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetWriteDeadline implements net.Conn.
|
||||
func (*testDnsConn) SetWriteDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Write implements net.Conn.
|
||||
func (c *testDnsConn) Write(b []byte) (n int, err error) {
|
||||
var p dnsmessage.Parser
|
||||
var h dnsmessage.Header
|
||||
if h, err = p.Start(b[2:]); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
c.questions, err = p.AllQuestions()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
for _, q := range c.questions {
|
||||
qStr := q.Name.String()
|
||||
if qStr != c.wantName {
|
||||
continue
|
||||
}
|
||||
|
||||
c.requested = true
|
||||
c.reqID = h.ID
|
||||
if err := p.SkipAllQuestions(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
type testConnStub struct{}
|
||||
|
||||
// Close implements net.Conn.
|
||||
func (*testConnStub) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// LocalAddr implements net.Conn.
|
||||
func (*testConnStub) LocalAddr() net.Addr {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
// Read implements net.Conn.
|
||||
func (*testConnStub) Read(b []byte) (n int, err error) {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
// RemoteAddr implements net.Conn.
|
||||
func (*testConnStub) RemoteAddr() net.Addr {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
// SetDeadline implements net.Conn.
|
||||
func (*testConnStub) SetDeadline(t time.Time) error {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
// SetReadDeadline implements net.Conn.
|
||||
func (*testConnStub) SetReadDeadline(t time.Time) error {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
// SetWriteDeadline implements net.Conn.
|
||||
func (*testConnStub) SetWriteDeadline(t time.Time) error {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
// Write implements net.Conn.
|
||||
func (*testConnStub) Write(b []byte) (n int, err error) {
|
||||
panic("unimplemented")
|
||||
}
|
|
@ -1,3 +1,5 @@
|
|||
//go:build integration
|
||||
|
||||
package multinet
|
||||
|
||||
import (
|
3
go.mod
3
go.mod
|
@ -6,7 +6,8 @@ require (
|
|||
github.com/stretchr/testify v1.8.4
|
||||
github.com/vishvananda/netlink v1.1.0
|
||||
github.com/vishvananda/netns v0.0.4
|
||||
golang.org/x/sys v0.2.0
|
||||
golang.org/x/net v0.17.0
|
||||
golang.org/x/sys v0.13.0
|
||||
)
|
||||
|
||||
require (
|
||||
|
|
6
go.sum
6
go.sum
|
@ -9,9 +9,11 @@ github.com/vishvananda/netlink v1.1.0/go.mod h1:cTgwzPIzzgDAYoQrMm0EdrjRUBkTqKYp
|
|||
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df/go.mod h1:JP3t17pCcGlemwknint6hfoeCVQrEMVwxRLRjXpq+BU=
|
||||
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
|
||||
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
|
||||
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
|
||||
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
|
||||
golang.org/x/sys v0.0.0-20190606203320-7fc4e5ec1444/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.2.0 h1:ljd4t30dBnAvMZaQCevtY0xLLD0A+bRZXbgLMLU1F/A=
|
||||
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
|
||||
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
//go:build integration
|
||||
|
||||
package multinet
|
||||
|
||||
import (
|
||||
|
@ -21,8 +23,13 @@ func Test_NetlinkWatcher(t *testing.T) {
|
|||
addr1 := &net.TCPAddr{IP: net.IP{1, 2, 30, 11}}
|
||||
addr2 := &net.TCPAddr{IP: net.IP{1, 2, 30, 12}}
|
||||
|
||||
result := make(chan net.Addr, 1)
|
||||
d, err := NewDialer(Config{
|
||||
Subnets: []string{"1.2.30.0/23"},
|
||||
DialContext: func(d *net.Dialer, _ context.Context, _, _ string) (net.Conn, error) {
|
||||
result <- d.LocalAddr
|
||||
return nil, nil
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
|
@ -30,12 +37,6 @@ func Test_NetlinkWatcher(t *testing.T) {
|
|||
require.NoError(t, w.Start())
|
||||
t.Cleanup(w.Stop)
|
||||
|
||||
result := make(chan net.Addr, 1)
|
||||
d.(*dialer).testHookDialContext = func(d *net.Dialer, _ context.Context, _, _ string) (net.Conn, error) {
|
||||
result <- d.LocalAddr
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
checkDialAddr(t, d, result, addr1)
|
||||
checkDialAddr(t, d, result, addr1)
|
||||
|
||||
|
@ -62,8 +63,13 @@ func Test_NetlinkWatcher(t *testing.T) {
|
|||
addr1 := &net.TCPAddr{IP: net.IP{1, 2, 30, 11}}
|
||||
addr2 := &net.TCPAddr{IP: net.IP{1, 2, 30, 12}}
|
||||
|
||||
result := make(chan net.Addr, 1)
|
||||
d, err := NewDialer(Config{
|
||||
Subnets: []string{"1.2.30.0/23"},
|
||||
DialContext: func(d *net.Dialer, _ context.Context, _, _ string) (net.Conn, error) {
|
||||
result <- d.LocalAddr
|
||||
return nil, nil
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
|
@ -71,12 +77,6 @@ func Test_NetlinkWatcher(t *testing.T) {
|
|||
require.NoError(t, w.Start())
|
||||
t.Cleanup(w.Stop)
|
||||
|
||||
result := make(chan net.Addr, 1)
|
||||
d.(*dialer).testHookDialContext = func(d *net.Dialer, _ context.Context, _, _ string) (net.Conn, error) {
|
||||
result <- d.LocalAddr
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
checkDialAddr(t, d, result, addr1)
|
||||
checkDialAddr(t, d, result, addr1)
|
||||
|
||||
|
@ -106,9 +106,14 @@ func Test_NetlinkWatcher(t *testing.T) {
|
|||
addr1 := &net.TCPAddr{IP: net.IP{1, 2, 30, 11}}
|
||||
addr2 := &net.TCPAddr{IP: net.IP{1, 2, 30, 12}}
|
||||
|
||||
result := make(chan net.Addr, 1)
|
||||
d, err := NewDialer(Config{
|
||||
Subnets: []string{"1.2.30.0/23"},
|
||||
Balancer: BalancerTypeRoundRobin,
|
||||
DialContext: func(d *net.Dialer, _ context.Context, _, _ string) (net.Conn, error) {
|
||||
result <- d.LocalAddr
|
||||
return nil, nil
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
|
@ -116,12 +121,6 @@ func Test_NetlinkWatcher(t *testing.T) {
|
|||
require.NoError(t, w.Start())
|
||||
t.Cleanup(w.Stop)
|
||||
|
||||
result := make(chan net.Addr, 1)
|
||||
d.(*dialer).testHookDialContext = func(d *net.Dialer, _ context.Context, _, _ string) (net.Conn, error) {
|
||||
result <- d.LocalAddr
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
checkDialAddr(t, d, result, addr2)
|
||||
checkDialAddr(t, d, result, addr1)
|
||||
checkDialAddr(t, d, result, addr2)
|
33
interface.go
Normal file
33
interface.go
Normal file
|
@ -0,0 +1,33 @@
|
|||
package multinet
|
||||
|
||||
import "net"
|
||||
|
||||
// Interface provides information about net.Interface.
|
||||
type Interface interface {
|
||||
Name() string
|
||||
Addrs() ([]net.Addr, error)
|
||||
}
|
||||
|
||||
type netInterface struct {
|
||||
iface net.Interface
|
||||
}
|
||||
|
||||
func (i *netInterface) Name() string {
|
||||
return i.iface.Name
|
||||
}
|
||||
|
||||
func (i *netInterface) Addrs() ([]net.Addr, error) {
|
||||
return i.iface.Addrs()
|
||||
}
|
||||
|
||||
func systemInterfaces() ([]Interface, error) {
|
||||
ifaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var result []Interface
|
||||
for _, iface := range ifaces {
|
||||
result = append(result, &netInterface{iface: iface})
|
||||
}
|
||||
return result, nil
|
||||
}
|
Loading…
Add table
Reference in a new issue
Do we have usecase for this besides using in tests?
For example tracing, logging, metrics.
It is called once during construction, so I doubt it is useful, but ok