Support hostnames #6

Merged
fyrchik merged 1 commit from dstepanov-yadro/multinet:feat/host_names into master 2024-09-04 19:51:22 +00:00
7 changed files with 387 additions and 29 deletions

View file

@ -1,3 +1,6 @@
integration-test:
# TODO figure out needed capabilities
sudo go test -count=1 -v ./... -tags=integration
test:
go test -count=1 -v ./...

View file

@ -41,8 +41,8 @@ type dialer struct {
restrict bool
// Algorithm to which picks source address for each ip.
balancer balancer
// Hook used in tests, overrides `net.Dialer.DialContext()`
testHookDialContext func(d *net.Dialer, ctx context.Context, network, address string) (net.Conn, error)
// Overrides `net.Dialer.DialContext()` if specified.
customDialContext func(d *net.Dialer, ctx context.Context, network, address string) (net.Conn, error)
// Hostname resolver
resolver net.Resolver
// See Config.FallbackDelay description.
@ -83,16 +83,28 @@ type Config struct {
// If zero, a default delay of 300ms is used.
// A negative value disables Fast Fallback support.
FallbackDelay time.Duration
// InterfaceSource is custom `Interface`` source.
// If not specified, default implementation is used (`net.Interfaces()``).
InterfaceSource func() ([]Interface, error)
fyrchik marked this conversation as resolved Outdated

Do we have usecase for this besides using in tests?

Do we have usecase for this besides using in tests?

For example tracing, logging, metrics.

For example tracing, logging, metrics.

It is called once during construction, so I doubt it is useful, but ok

It is called once during construction, so I doubt it is useful, but ok
// DialContext is custom DialContext function.
// If not specified, default implemenattion is used (`d.DialContext(ctx, network, address)`).
DialContext func(d *net.Dialer, ctx context.Context, network, address string) (net.Conn, error)
}
// NewDialer ...
func NewDialer(c Config) (Multidialer, error) {
ifaces, err := net.Interfaces()
var ifaces []Interface
var err error
if c.InterfaceSource != nil {
ifaces, err = c.InterfaceSource()
} else {
ifaces, err = systemInterfaces()
}
if err != nil {
return nil, err
}
sort.Slice(ifaces, func(i, j int) bool {
return ifaces[i].Name < ifaces[j].Name
return ifaces[i].Name() < ifaces[j].Name()
})
var sources []iface
@ -128,16 +140,20 @@ func NewDialer(c Config) (Multidialer, error) {
if d.fallbackDelay == 0 {
d.fallbackDelay = defaultFallbackDelay
}
d.dialer = c.Dialer
if c.DialContext != nil {
d.customDialContext = c.DialContext
}
return &d, nil
}
type iface struct {
info net.Interface
name string
addrs []netip.Prefix
}
func processIface(info net.Interface) (iface, error) {
func processIface(info Interface) (iface, error) {
ips, err := info.Addrs()
if err != nil {
return iface{}, err
@ -152,7 +168,7 @@ func processIface(info net.Interface) (iface, error) {
addrs = append(addrs, p)
}
return iface{info: info, addrs: addrs}, nil
return iface{name: info.Name(), addrs: addrs}, nil
}
func processSubnet(subnet string, sources []iface) (Subnet, error) {
@ -167,7 +183,7 @@ func processSubnet(subnet string, sources []iface) (Subnet, error) {
src := source.addrs[i].Addr()
if s.Contains(src) {
ifs = append(ifs, Source{
Name: source.info.Name,
Name: source.name,
LocalAddr: &net.TCPAddr{IP: net.IP(src.AsSlice())},
})
}
@ -361,11 +377,11 @@ func (d *dialer) dialAddr(ctx context.Context, network, address string, addr net
if d.restrict {
return nil, fmt.Errorf("no suitable interface for: [%s]%s", network, address)
}
return d.dialer.DialContext(ctx, network, address)
return d.dialContext(&d.dialer, ctx, network, address)
}
func (d *dialer) dialContext(nd *net.Dialer, ctx context.Context, network, address string) (net.Conn, error) {
fyrchik marked this conversation as resolved Outdated

What was wrong with testHook? It is in spirit of the related stdlib pieces b5f87b5407/src/net/dial.go (L423)

What was wrong with `testHook`? It is in spirit of the related stdlib pieces https://github.com/golang/go/blob/b5f87b5407916c4049a3158cc944cebfd7a883a9/src/net/dial.go#L423

Nothing wrong. But I guess it could be used not only for tests. For example tracing, logging, metrics.

Nothing wrong. But I guess it could be used not only for tests. For example tracing, logging, metrics.
if h := d.testHookDialContext; h != nil {
if h := d.customDialContext; h != nil {
return h(nd, ctx, network, address)
}
return nd.DialContext(ctx, network, address)

308
dialer_hostname_test.go Normal file
View file

@ -0,0 +1,308 @@
package multinet
import (
"context"
"net"
"testing"
"time"
"github.com/stretchr/testify/require"
"golang.org/x/net/dns/dnsmessage"
)
func TestHostnameResolveIPv4(t *testing.T) {
resolvedAddr := "10.11.12.180:8080"
resolved := false
d, err := NewDialer(Config{
Subnets: []string{"10.11.12.0/24"},
InterfaceSource: testInterfacesV4,
DialContext: func(d *net.Dialer, ctx context.Context, network, address string) (net.Conn, error) {
if resolvedAddr == address {
resolved = true
return &testConnStub{}, nil
}
if network == "udp" {
return &testDnsConn{
wantName: "s01.storage.devnev.",
ipv4: []byte{10, 11, 12, 180},
}, nil
}
panic("unexpected call")
},
})
require.NoError(t, err)
conn, err := d.DialContext(context.Background(), "tcp", "s01.storage.devnev:8080")
require.NoError(t, err)
require.NoError(t, conn.Close())
require.True(t, resolved)
}
func TestHostnameResolveIPv6(t *testing.T) {
resolvedAddr := "[2001:db8:85a3:8d3:1319:8a2e:370:8195]:8080"
ipv6 := net.ParseIP("2001:db8:85a3:8d3:1319:8a2e:370:8195")
resolved := false
d, err := NewDialer(Config{
Subnets: []string{"2001:db8:85a3:8d3::/64"},
InterfaceSource: testInterfacesV6,
DialContext: func(d *net.Dialer, ctx context.Context, network, address string) (net.Conn, error) {
if resolvedAddr == address {
resolved = true
return &testConnStub{}, nil
}
if network == "udp" {
return &testDnsConn{
wantName: "s01.storage.devnev.",
ipv6: ipv6,
}, nil
}
panic("unexpected call")
},
})
require.NoError(t, err)
conn, err := d.DialContext(context.Background(), "tcp", "s01.storage.devnev:8080")
require.NoError(t, err)
require.NoError(t, conn.Close())
require.True(t, resolved)
}
func testInterfacesV4() ([]Interface, error) {
return []Interface{
&testInterface{
name: "data1",
addrs: []net.Addr{
&testAddr{
network: "tcp",
str: "10.11.12.101/24",
},
},
},
&testInterface{
name: "data2",
addrs: []net.Addr{
&testAddr{
network: "tcp",
str: "10.11.12.102/24",
},
},
},
}, nil
}
func testInterfacesV6() ([]Interface, error) {
return []Interface{
&testInterface{
name: "data1",
addrs: []net.Addr{
&testAddr{
network: "tcp",
str: "2001:db8:85a3:8d3:1319:8a2e:370:7348/64",
},
},
},
&testInterface{
name: "data2",
addrs: []net.Addr{
&testAddr{
network: "tcp",
str: "2001:db8:85a3:8d3:1319:8a2e:370:8192/64",
},
},
},
}, nil
}
type testInterface struct {
name string
addrs []net.Addr
}
func (i *testInterface) Name() string { return i.name }
func (i *testInterface) Addrs() ([]net.Addr, error) { return i.addrs, nil }
type testAddr struct {
network string
str string
}
func (a *testAddr) Network() string { return a.network }
func (a *testAddr) String() string { return a.str }
type testDnsConn struct {
wantName string
ipv4 []byte
ipv6 []byte
reqID uint16
requested bool
secondRead bool
questions []dnsmessage.Question
}
// Close implements net.Conn.
func (*testDnsConn) Close() error {
return nil
}
// LocalAddr implements net.Conn.
func (*testDnsConn) LocalAddr() net.Addr {
panic("unimplemented")
}
// Read implements net.Conn.
func (c *testDnsConn) Read(b []byte) (n int, err error) {
if c.secondRead {
data, err := c.data()
if err != nil {
return 0, err
}
copy(b, data)
return len(data), nil
}
data, err := c.data()
if err != nil {
return 0, err
}
l := len(data)
b[0] = byte(l >> 8)
b[1] = byte(l)
c.secondRead = true
return 2, nil
}
func (c *testDnsConn) data() ([]byte, error) {
if !c.requested {
builder := dnsmessage.NewBuilder(nil, dnsmessage.Header{RCode: dnsmessage.RCodeNameError})
res, err := builder.Finish()
if err != nil {
return nil, err
}
return res, nil
}
msg := dnsmessage.Message{
Header: dnsmessage.Header{Response: true, Authoritative: true, RCode: dnsmessage.RCodeSuccess, ID: c.reqID},
Questions: c.questions,
Answers: []dnsmessage.Resource{},
}
c.appendIPv4(&msg)
c.appendIPV6(&msg)
return msg.Pack()
}
func (c *testDnsConn) appendIPv4(msg *dnsmessage.Message) {
if c.ipv4 != nil {
msg.Answers = append(msg.Answers, dnsmessage.Resource{
Header: dnsmessage.ResourceHeader{
Name: dnsmessage.MustNewName(c.wantName),
Type: dnsmessage.TypeA,
Class: dnsmessage.ClassINET,
},
Body: &dnsmessage.AResource{A: [4]byte(c.ipv4)},
})
}
}
func (c *testDnsConn) appendIPV6(msg *dnsmessage.Message) {
if c.ipv6 != nil {
msg.Answers = append(msg.Answers, dnsmessage.Resource{
Header: dnsmessage.ResourceHeader{
Name: dnsmessage.MustNewName(c.wantName),
Type: dnsmessage.TypeAAAA,
Class: dnsmessage.ClassINET,
},
Body: &dnsmessage.AAAAResource{AAAA: [16]byte(c.ipv6)},
})
}
}
// RemoteAddr implements net.Conn.
func (*testDnsConn) RemoteAddr() net.Addr {
panic("unimplemented")
}
// SetDeadline implements net.Conn.
func (*testDnsConn) SetDeadline(t time.Time) error {
return nil
}
// SetReadDeadline implements net.Conn.
func (*testDnsConn) SetReadDeadline(t time.Time) error {
return nil
}
// SetWriteDeadline implements net.Conn.
func (*testDnsConn) SetWriteDeadline(t time.Time) error {
return nil
}
// Write implements net.Conn.
func (c *testDnsConn) Write(b []byte) (n int, err error) {
var p dnsmessage.Parser
var h dnsmessage.Header
if h, err = p.Start(b[2:]); err != nil {
return 0, err
}
c.questions, err = p.AllQuestions()
if err != nil {
return 0, err
}
for _, q := range c.questions {
qStr := q.Name.String()
if qStr != c.wantName {
continue
}
c.requested = true
c.reqID = h.ID
if err := p.SkipAllQuestions(); err != nil {
return 0, err
}
break
}
return len(b), nil
}
type testConnStub struct{}
// Close implements net.Conn.
func (*testConnStub) Close() error {
return nil
}
// LocalAddr implements net.Conn.
func (*testConnStub) LocalAddr() net.Addr {
panic("unimplemented")
}
// Read implements net.Conn.
func (*testConnStub) Read(b []byte) (n int, err error) {
panic("unimplemented")
}
// RemoteAddr implements net.Conn.
func (*testConnStub) RemoteAddr() net.Addr {
panic("unimplemented")
}
// SetDeadline implements net.Conn.
func (*testConnStub) SetDeadline(t time.Time) error {
panic("unimplemented")
}
// SetReadDeadline implements net.Conn.
func (*testConnStub) SetReadDeadline(t time.Time) error {
panic("unimplemented")
}
// SetWriteDeadline implements net.Conn.
func (*testConnStub) SetWriteDeadline(t time.Time) error {
panic("unimplemented")
}
// Write implements net.Conn.
func (*testConnStub) Write(b []byte) (n int, err error) {
panic("unimplemented")
}

3
go.mod
View file

@ -6,7 +6,8 @@ require (
github.com/stretchr/testify v1.8.4
github.com/vishvananda/netlink v1.1.0
github.com/vishvananda/netns v0.0.4
golang.org/x/sys v0.2.0
golang.org/x/net v0.17.0
golang.org/x/sys v0.13.0
)
require (

BIN
go.sum

Binary file not shown.

View file

@ -23,8 +23,13 @@ func Test_NetlinkWatcher(t *testing.T) {
addr1 := &net.TCPAddr{IP: net.IP{1, 2, 30, 11}}
addr2 := &net.TCPAddr{IP: net.IP{1, 2, 30, 12}}
result := make(chan net.Addr, 1)
d, err := NewDialer(Config{
Subnets: []string{"1.2.30.0/23"},
DialContext: func(d *net.Dialer, _ context.Context, _, _ string) (net.Conn, error) {
result <- d.LocalAddr
return nil, nil
},
})
require.NoError(t, err)
@ -32,12 +37,6 @@ func Test_NetlinkWatcher(t *testing.T) {
require.NoError(t, w.Start())
t.Cleanup(w.Stop)
result := make(chan net.Addr, 1)
d.(*dialer).testHookDialContext = func(d *net.Dialer, _ context.Context, _, _ string) (net.Conn, error) {
result <- d.LocalAddr
return nil, nil
}
checkDialAddr(t, d, result, addr1)
checkDialAddr(t, d, result, addr1)
@ -64,8 +63,13 @@ func Test_NetlinkWatcher(t *testing.T) {
addr1 := &net.TCPAddr{IP: net.IP{1, 2, 30, 11}}
addr2 := &net.TCPAddr{IP: net.IP{1, 2, 30, 12}}
result := make(chan net.Addr, 1)
d, err := NewDialer(Config{
Subnets: []string{"1.2.30.0/23"},
DialContext: func(d *net.Dialer, _ context.Context, _, _ string) (net.Conn, error) {
result <- d.LocalAddr
return nil, nil
},
})
require.NoError(t, err)
@ -73,12 +77,6 @@ func Test_NetlinkWatcher(t *testing.T) {
require.NoError(t, w.Start())
t.Cleanup(w.Stop)
result := make(chan net.Addr, 1)
d.(*dialer).testHookDialContext = func(d *net.Dialer, _ context.Context, _, _ string) (net.Conn, error) {
result <- d.LocalAddr
return nil, nil
}
checkDialAddr(t, d, result, addr1)
checkDialAddr(t, d, result, addr1)
@ -108,9 +106,14 @@ func Test_NetlinkWatcher(t *testing.T) {
addr1 := &net.TCPAddr{IP: net.IP{1, 2, 30, 11}}
addr2 := &net.TCPAddr{IP: net.IP{1, 2, 30, 12}}
result := make(chan net.Addr, 1)
d, err := NewDialer(Config{
Subnets: []string{"1.2.30.0/23"},
Balancer: BalancerTypeRoundRobin,
DialContext: func(d *net.Dialer, _ context.Context, _, _ string) (net.Conn, error) {
result <- d.LocalAddr
return nil, nil
},
})
require.NoError(t, err)
@ -118,12 +121,6 @@ func Test_NetlinkWatcher(t *testing.T) {
require.NoError(t, w.Start())
t.Cleanup(w.Stop)
result := make(chan net.Addr, 1)
d.(*dialer).testHookDialContext = func(d *net.Dialer, _ context.Context, _, _ string) (net.Conn, error) {
result <- d.LocalAddr
return nil, nil
}
checkDialAddr(t, d, result, addr2)
checkDialAddr(t, d, result, addr1)
checkDialAddr(t, d, result, addr2)

33
interface.go Normal file
View file

@ -0,0 +1,33 @@
package multinet
import "net"
// Interface provides information about net.Interface.
type Interface interface {
Name() string
Addrs() ([]net.Addr, error)
}
type netInterface struct {
iface net.Interface
}
func (i *netInterface) Name() string {
return i.iface.Name
}
func (i *netInterface) Addrs() ([]net.Addr, error) {
return i.iface.Addrs()
}
func systemInterfaces() ([]Interface, error) {
ifaces, err := net.Interfaces()
if err != nil {
return nil, err
}
var result []Interface
for _, iface := range ifaces {
result = append(result, &netInterface{iface: iface})
}
return result, nil
}