[#4] dialer: Add hostname resolve tests

Signed-off-by: Dmitrii Stepanov <d.stepanov@yadro.com>
This commit is contained in:
Dmitrii Stepanov 2023-10-24 15:43:08 +03:00
parent 37b0350e95
commit 5298ec4295
7 changed files with 387 additions and 29 deletions

View file

@ -1,3 +1,6 @@
integration-test:
# TODO figure out needed capabilities
sudo go test -count=1 -v ./... -tags=integration
test:
go test -count=1 -v ./...

View file

@ -41,8 +41,8 @@ type dialer struct {
restrict bool
// Algorithm to which picks source address for each ip.
balancer balancer
// Hook used in tests, overrides `net.Dialer.DialContext()`
testHookDialContext func(d *net.Dialer, ctx context.Context, network, address string) (net.Conn, error)
// Overrides `net.Dialer.DialContext()` if specified.
customDialContext func(d *net.Dialer, ctx context.Context, network, address string) (net.Conn, error)
// Hostname resolver
resolver net.Resolver
// See Config.FallbackDelay description.
@ -83,16 +83,28 @@ type Config struct {
// If zero, a default delay of 300ms is used.
// A negative value disables Fast Fallback support.
FallbackDelay time.Duration
// InterfaceSource is custom `Interface`` source.
// If not specified, default implementation is used (`net.Interfaces()``).
InterfaceSource func() ([]Interface, error)
// DialContext is custom DialContext function.
// If not specified, default implemenattion is used (`d.DialContext(ctx, network, address)`).
DialContext func(d *net.Dialer, ctx context.Context, network, address string) (net.Conn, error)
}
// NewDialer ...
func NewDialer(c Config) (Multidialer, error) {
ifaces, err := net.Interfaces()
var ifaces []Interface
var err error
if c.InterfaceSource != nil {
ifaces, err = c.InterfaceSource()
} else {
ifaces, err = systemInterfaces()
}
if err != nil {
return nil, err
}
sort.Slice(ifaces, func(i, j int) bool {
return ifaces[i].Name < ifaces[j].Name
return ifaces[i].Name() < ifaces[j].Name()
})
var sources []iface
@ -128,16 +140,20 @@ func NewDialer(c Config) (Multidialer, error) {
if d.fallbackDelay == 0 {
d.fallbackDelay = defaultFallbackDelay
}
d.dialer = c.Dialer
if c.DialContext != nil {
d.customDialContext = c.DialContext
}
return &d, nil
}
type iface struct {
info net.Interface
name string
addrs []netip.Prefix
}
func processIface(info net.Interface) (iface, error) {
func processIface(info Interface) (iface, error) {
ips, err := info.Addrs()
if err != nil {
return iface{}, err
@ -152,7 +168,7 @@ func processIface(info net.Interface) (iface, error) {
addrs = append(addrs, p)
}
return iface{info: info, addrs: addrs}, nil
return iface{name: info.Name(), addrs: addrs}, nil
}
func processSubnet(subnet string, sources []iface) (Subnet, error) {
@ -167,7 +183,7 @@ func processSubnet(subnet string, sources []iface) (Subnet, error) {
src := source.addrs[i].Addr()
if s.Contains(src) {
ifs = append(ifs, Source{
Name: source.info.Name,
Name: source.name,
LocalAddr: &net.TCPAddr{IP: net.IP(src.AsSlice())},
})
}
@ -361,11 +377,11 @@ func (d *dialer) dialAddr(ctx context.Context, network, address string, addr net
if d.restrict {
return nil, fmt.Errorf("no suitable interface for: [%s]%s", network, address)
}
return d.dialer.DialContext(ctx, network, address)
return d.dialContext(&d.dialer, ctx, network, address)
}
func (d *dialer) dialContext(nd *net.Dialer, ctx context.Context, network, address string) (net.Conn, error) {
if h := d.testHookDialContext; h != nil {
if h := d.customDialContext; h != nil {
return h(nd, ctx, network, address)
}
return nd.DialContext(ctx, network, address)

308
dialer_hostname_test.go Normal file
View file

@ -0,0 +1,308 @@
package multinet
import (
"context"
"net"
"testing"
"time"
"github.com/stretchr/testify/require"
"golang.org/x/net/dns/dnsmessage"
)
func TestHostnameResolveIPv4(t *testing.T) {
resolvedAddr := "10.11.12.180:8080"
resolved := false
d, err := NewDialer(Config{
Subnets: []string{"10.11.12.0/24"},
InterfaceSource: testInterfacesV4,
DialContext: func(d *net.Dialer, ctx context.Context, network, address string) (net.Conn, error) {
if resolvedAddr == address {
resolved = true
return &testConnStub{}, nil
}
if network == "udp" {
return &testDnsConn{
wantName: "s01.storage.devnev.",
ipv4: []byte{10, 11, 12, 180},
}, nil
}
panic("unexpected call")
},
})
require.NoError(t, err)
conn, err := d.DialContext(context.Background(), "tcp", "s01.storage.devnev:8080")
require.NoError(t, err)
require.NoError(t, conn.Close())
require.True(t, resolved)
}
func TestHostnameResolveIPv6(t *testing.T) {
resolvedAddr := "[2001:db8:85a3:8d3:1319:8a2e:370:8195]:8080"
ipv6 := net.ParseIP("2001:db8:85a3:8d3:1319:8a2e:370:8195")
resolved := false
d, err := NewDialer(Config{
Subnets: []string{"2001:db8:85a3:8d3::/64"},
InterfaceSource: testInterfacesV6,
DialContext: func(d *net.Dialer, ctx context.Context, network, address string) (net.Conn, error) {
if resolvedAddr == address {
resolved = true
return &testConnStub{}, nil
}
if network == "udp" {
return &testDnsConn{
wantName: "s01.storage.devnev.",
ipv6: ipv6,
}, nil
}
panic("unexpected call")
},
})
require.NoError(t, err)
conn, err := d.DialContext(context.Background(), "tcp", "s01.storage.devnev:8080")
require.NoError(t, err)
require.NoError(t, conn.Close())
require.True(t, resolved)
}
func testInterfacesV4() ([]Interface, error) {
return []Interface{
&testInterface{
name: "data1",
addrs: []net.Addr{
&testAddr{
network: "tcp",
str: "10.11.12.101/24",
},
},
},
&testInterface{
name: "data2",
addrs: []net.Addr{
&testAddr{
network: "tcp",
str: "10.11.12.102/24",
},
},
},
}, nil
}
func testInterfacesV6() ([]Interface, error) {
return []Interface{
&testInterface{
name: "data1",
addrs: []net.Addr{
&testAddr{
network: "tcp",
str: "2001:db8:85a3:8d3:1319:8a2e:370:7348/64",
},
},
},
&testInterface{
name: "data2",
addrs: []net.Addr{
&testAddr{
network: "tcp",
str: "2001:db8:85a3:8d3:1319:8a2e:370:8192/64",
},
},
},
}, nil
}
type testInterface struct {
name string
addrs []net.Addr
}
func (i *testInterface) Name() string { return i.name }
func (i *testInterface) Addrs() ([]net.Addr, error) { return i.addrs, nil }
type testAddr struct {
network string
str string
}
func (a *testAddr) Network() string { return a.network }
func (a *testAddr) String() string { return a.str }
type testDnsConn struct {
wantName string
ipv4 []byte
ipv6 []byte
reqID uint16
requested bool
secondRead bool
questions []dnsmessage.Question
}
// Close implements net.Conn.
func (*testDnsConn) Close() error {
return nil
}
// LocalAddr implements net.Conn.
func (*testDnsConn) LocalAddr() net.Addr {
panic("unimplemented")
}
// Read implements net.Conn.
func (c *testDnsConn) Read(b []byte) (n int, err error) {
if c.secondRead {
data, err := c.data()
if err != nil {
return 0, err
}
copy(b, data)
return len(data), nil
}
data, err := c.data()
if err != nil {
return 0, err
}
l := len(data)
b[0] = byte(l >> 8)
b[1] = byte(l)
c.secondRead = true
return 2, nil
}
func (c *testDnsConn) data() ([]byte, error) {
if !c.requested {
builder := dnsmessage.NewBuilder(nil, dnsmessage.Header{RCode: dnsmessage.RCodeNameError})
res, err := builder.Finish()
if err != nil {
return nil, err
}
return res, nil
}
msg := dnsmessage.Message{
Header: dnsmessage.Header{Response: true, Authoritative: true, RCode: dnsmessage.RCodeSuccess, ID: c.reqID},
Questions: c.questions,
Answers: []dnsmessage.Resource{},
}
c.appendIPv4(&msg)
c.appendIPV6(&msg)
return msg.Pack()
}
func (c *testDnsConn) appendIPv4(msg *dnsmessage.Message) {
if c.ipv4 != nil {
msg.Answers = append(msg.Answers, dnsmessage.Resource{
Header: dnsmessage.ResourceHeader{
Name: dnsmessage.MustNewName(c.wantName),
Type: dnsmessage.TypeA,
Class: dnsmessage.ClassINET,
},
Body: &dnsmessage.AResource{A: [4]byte(c.ipv4)},
})
}
}
func (c *testDnsConn) appendIPV6(msg *dnsmessage.Message) {
if c.ipv6 != nil {
msg.Answers = append(msg.Answers, dnsmessage.Resource{
Header: dnsmessage.ResourceHeader{
Name: dnsmessage.MustNewName(c.wantName),
Type: dnsmessage.TypeAAAA,
Class: dnsmessage.ClassINET,
},
Body: &dnsmessage.AAAAResource{AAAA: [16]byte(c.ipv6)},
})
}
}
// RemoteAddr implements net.Conn.
func (*testDnsConn) RemoteAddr() net.Addr {
panic("unimplemented")
}
// SetDeadline implements net.Conn.
func (*testDnsConn) SetDeadline(t time.Time) error {
return nil
}
// SetReadDeadline implements net.Conn.
func (*testDnsConn) SetReadDeadline(t time.Time) error {
return nil
}
// SetWriteDeadline implements net.Conn.
func (*testDnsConn) SetWriteDeadline(t time.Time) error {
return nil
}
// Write implements net.Conn.
func (c *testDnsConn) Write(b []byte) (n int, err error) {
var p dnsmessage.Parser
var h dnsmessage.Header
if h, err = p.Start(b[2:]); err != nil {
return 0, err
}
c.questions, err = p.AllQuestions()
if err != nil {
return 0, err
}
for _, q := range c.questions {
qStr := q.Name.String()
if qStr != c.wantName {
continue
}
c.requested = true
c.reqID = h.ID
if err := p.SkipAllQuestions(); err != nil {
return 0, err
}
break
}
return len(b), nil
}
type testConnStub struct{}
// Close implements net.Conn.
func (*testConnStub) Close() error {
return nil
}
// LocalAddr implements net.Conn.
func (*testConnStub) LocalAddr() net.Addr {
panic("unimplemented")
}
// Read implements net.Conn.
func (*testConnStub) Read(b []byte) (n int, err error) {
panic("unimplemented")
}
// RemoteAddr implements net.Conn.
func (*testConnStub) RemoteAddr() net.Addr {
panic("unimplemented")
}
// SetDeadline implements net.Conn.
func (*testConnStub) SetDeadline(t time.Time) error {
panic("unimplemented")
}
// SetReadDeadline implements net.Conn.
func (*testConnStub) SetReadDeadline(t time.Time) error {
panic("unimplemented")
}
// SetWriteDeadline implements net.Conn.
func (*testConnStub) SetWriteDeadline(t time.Time) error {
panic("unimplemented")
}
// Write implements net.Conn.
func (*testConnStub) Write(b []byte) (n int, err error) {
panic("unimplemented")
}

3
go.mod
View file

@ -6,7 +6,8 @@ require (
github.com/stretchr/testify v1.8.4
github.com/vishvananda/netlink v1.1.0
github.com/vishvananda/netns v0.0.4
golang.org/x/sys v0.2.0
golang.org/x/net v0.17.0
golang.org/x/sys v0.13.0
)
require (

BIN
go.sum

Binary file not shown.

View file

@ -23,8 +23,13 @@ func Test_NetlinkWatcher(t *testing.T) {
addr1 := &net.TCPAddr{IP: net.IP{1, 2, 30, 11}}
addr2 := &net.TCPAddr{IP: net.IP{1, 2, 30, 12}}
result := make(chan net.Addr, 1)
d, err := NewDialer(Config{
Subnets: []string{"1.2.30.0/23"},
DialContext: func(d *net.Dialer, _ context.Context, _, _ string) (net.Conn, error) {
result <- d.LocalAddr
return nil, nil
},
})
require.NoError(t, err)
@ -32,12 +37,6 @@ func Test_NetlinkWatcher(t *testing.T) {
require.NoError(t, w.Start())
t.Cleanup(w.Stop)
result := make(chan net.Addr, 1)
d.(*dialer).testHookDialContext = func(d *net.Dialer, _ context.Context, _, _ string) (net.Conn, error) {
result <- d.LocalAddr
return nil, nil
}
checkDialAddr(t, d, result, addr1)
checkDialAddr(t, d, result, addr1)
@ -64,8 +63,13 @@ func Test_NetlinkWatcher(t *testing.T) {
addr1 := &net.TCPAddr{IP: net.IP{1, 2, 30, 11}}
addr2 := &net.TCPAddr{IP: net.IP{1, 2, 30, 12}}
result := make(chan net.Addr, 1)
d, err := NewDialer(Config{
Subnets: []string{"1.2.30.0/23"},
DialContext: func(d *net.Dialer, _ context.Context, _, _ string) (net.Conn, error) {
result <- d.LocalAddr
return nil, nil
},
})
require.NoError(t, err)
@ -73,12 +77,6 @@ func Test_NetlinkWatcher(t *testing.T) {
require.NoError(t, w.Start())
t.Cleanup(w.Stop)
result := make(chan net.Addr, 1)
d.(*dialer).testHookDialContext = func(d *net.Dialer, _ context.Context, _, _ string) (net.Conn, error) {
result <- d.LocalAddr
return nil, nil
}
checkDialAddr(t, d, result, addr1)
checkDialAddr(t, d, result, addr1)
@ -108,9 +106,14 @@ func Test_NetlinkWatcher(t *testing.T) {
addr1 := &net.TCPAddr{IP: net.IP{1, 2, 30, 11}}
addr2 := &net.TCPAddr{IP: net.IP{1, 2, 30, 12}}
result := make(chan net.Addr, 1)
d, err := NewDialer(Config{
Subnets: []string{"1.2.30.0/23"},
Balancer: BalancerTypeRoundRobin,
DialContext: func(d *net.Dialer, _ context.Context, _, _ string) (net.Conn, error) {
result <- d.LocalAddr
return nil, nil
},
})
require.NoError(t, err)
@ -118,12 +121,6 @@ func Test_NetlinkWatcher(t *testing.T) {
require.NoError(t, w.Start())
t.Cleanup(w.Stop)
result := make(chan net.Addr, 1)
d.(*dialer).testHookDialContext = func(d *net.Dialer, _ context.Context, _, _ string) (net.Conn, error) {
result <- d.LocalAddr
return nil, nil
}
checkDialAddr(t, d, result, addr2)
checkDialAddr(t, d, result, addr1)
checkDialAddr(t, d, result, addr2)

33
interface.go Normal file
View file

@ -0,0 +1,33 @@
package multinet
import "net"
// Interface provides information about net.Interface.
type Interface interface {
Name() string
Addrs() ([]net.Addr, error)
}
type netInterface struct {
iface net.Interface
}
func (i *netInterface) Name() string {
return i.iface.Name
}
func (i *netInterface) Addrs() ([]net.Addr, error) {
return i.iface.Addrs()
}
func systemInterfaces() ([]Interface, error) {
ifaces, err := net.Interfaces()
if err != nil {
return nil, err
}
var result []Interface
for _, iface := range ifaces {
result = append(result, &netInterface{iface: iface})
}
return result, nil
}