forked from TrueCloudLab/multinet
308 lines
6.6 KiB
Go
308 lines
6.6 KiB
Go
package multinet
|
|
|
|
import (
|
|
"context"
|
|
"net"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/stretchr/testify/require"
|
|
"golang.org/x/net/dns/dnsmessage"
|
|
)
|
|
|
|
func TestHostnameResolveIPv4(t *testing.T) {
|
|
resolvedAddr := "10.11.12.180:8080"
|
|
resolved := false
|
|
d, err := NewDialer(Config{
|
|
Subnets: []string{"10.11.12.0/24"},
|
|
InterfaceSource: testInterfacesV4,
|
|
DialContext: func(d *net.Dialer, ctx context.Context, network, address string) (net.Conn, error) {
|
|
if resolvedAddr == address {
|
|
resolved = true
|
|
return &testConnStub{}, nil
|
|
}
|
|
if network == "udp" {
|
|
return &testDnsConn{
|
|
wantName: "s01.storage.devnev.",
|
|
ipv4: []byte{10, 11, 12, 180},
|
|
}, nil
|
|
}
|
|
panic("unexpected call")
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
conn, err := d.DialContext(context.Background(), "tcp", "s01.storage.devnev:8080")
|
|
require.NoError(t, err)
|
|
require.NoError(t, conn.Close())
|
|
require.True(t, resolved)
|
|
}
|
|
|
|
func TestHostnameResolveIPv6(t *testing.T) {
|
|
resolvedAddr := "[2001:db8:85a3:8d3:1319:8a2e:370:8195]:8080"
|
|
ipv6 := net.ParseIP("2001:db8:85a3:8d3:1319:8a2e:370:8195")
|
|
resolved := false
|
|
d, err := NewDialer(Config{
|
|
Subnets: []string{"2001:db8:85a3:8d3::/64"},
|
|
InterfaceSource: testInterfacesV6,
|
|
DialContext: func(d *net.Dialer, ctx context.Context, network, address string) (net.Conn, error) {
|
|
if resolvedAddr == address {
|
|
resolved = true
|
|
return &testConnStub{}, nil
|
|
}
|
|
if network == "udp" {
|
|
return &testDnsConn{
|
|
wantName: "s01.storage.devnev.",
|
|
ipv6: ipv6,
|
|
}, nil
|
|
}
|
|
panic("unexpected call")
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
conn, err := d.DialContext(context.Background(), "tcp", "s01.storage.devnev:8080")
|
|
require.NoError(t, err)
|
|
require.NoError(t, conn.Close())
|
|
require.True(t, resolved)
|
|
}
|
|
|
|
func testInterfacesV4() ([]Interface, error) {
|
|
return []Interface{
|
|
&testInterface{
|
|
name: "data1",
|
|
addrs: []net.Addr{
|
|
&testAddr{
|
|
network: "tcp",
|
|
str: "10.11.12.101/24",
|
|
},
|
|
},
|
|
},
|
|
&testInterface{
|
|
name: "data2",
|
|
addrs: []net.Addr{
|
|
&testAddr{
|
|
network: "tcp",
|
|
str: "10.11.12.102/24",
|
|
},
|
|
},
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
func testInterfacesV6() ([]Interface, error) {
|
|
return []Interface{
|
|
&testInterface{
|
|
name: "data1",
|
|
addrs: []net.Addr{
|
|
&testAddr{
|
|
network: "tcp",
|
|
str: "2001:db8:85a3:8d3:1319:8a2e:370:7348/64",
|
|
},
|
|
},
|
|
},
|
|
&testInterface{
|
|
name: "data2",
|
|
addrs: []net.Addr{
|
|
&testAddr{
|
|
network: "tcp",
|
|
str: "2001:db8:85a3:8d3:1319:8a2e:370:8192/64",
|
|
},
|
|
},
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
// testInterface is a fixed, in-memory Interface for the tests. It reports the
// name and addresses it was built with.
type testInterface struct {
	name  string
	addrs []net.Addr
}

// Name returns the configured interface name.
func (i *testInterface) Name() string {
	return i.name
}

// Addrs returns the configured address list as-is; the error is always nil.
func (i *testInterface) Addrs() ([]net.Addr, error) {
	return i.addrs, nil
}
|
|
|
|
// testAddr is a net.Addr whose network name and string form are fixed at
// construction. The tests use it to describe interface addresses.
type testAddr struct {
	network, str string
}

// Network returns the configured network name.
func (a *testAddr) Network() string {
	return a.network
}

// String returns the configured textual address.
func (a *testAddr) String() string {
	return a.str
}
|
|
|
|
type testDnsConn struct {
|
|
wantName string
|
|
ipv4 []byte
|
|
ipv6 []byte
|
|
reqID uint16
|
|
requested bool
|
|
secondRead bool
|
|
questions []dnsmessage.Question
|
|
}
|
|
|
|
// Close implements net.Conn.
|
|
func (*testDnsConn) Close() error {
|
|
return nil
|
|
}
|
|
|
|
// LocalAddr implements net.Conn.
|
|
func (*testDnsConn) LocalAddr() net.Addr {
|
|
panic("unimplemented")
|
|
}
|
|
|
|
// Read implements net.Conn.
|
|
func (c *testDnsConn) Read(b []byte) (n int, err error) {
|
|
if c.secondRead {
|
|
data, err := c.data()
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
copy(b, data)
|
|
return len(data), nil
|
|
}
|
|
|
|
data, err := c.data()
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
l := len(data)
|
|
b[0] = byte(l >> 8)
|
|
b[1] = byte(l)
|
|
c.secondRead = true
|
|
return 2, nil
|
|
}
|
|
|
|
func (c *testDnsConn) data() ([]byte, error) {
|
|
if !c.requested {
|
|
builder := dnsmessage.NewBuilder(nil, dnsmessage.Header{RCode: dnsmessage.RCodeNameError})
|
|
res, err := builder.Finish()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return res, nil
|
|
}
|
|
|
|
msg := dnsmessage.Message{
|
|
Header: dnsmessage.Header{Response: true, Authoritative: true, RCode: dnsmessage.RCodeSuccess, ID: c.reqID},
|
|
Questions: c.questions,
|
|
Answers: []dnsmessage.Resource{},
|
|
}
|
|
c.appendIPv4(&msg)
|
|
c.appendIPV6(&msg)
|
|
return msg.Pack()
|
|
}
|
|
|
|
func (c *testDnsConn) appendIPv4(msg *dnsmessage.Message) {
|
|
if c.ipv4 != nil {
|
|
msg.Answers = append(msg.Answers, dnsmessage.Resource{
|
|
Header: dnsmessage.ResourceHeader{
|
|
Name: dnsmessage.MustNewName(c.wantName),
|
|
Type: dnsmessage.TypeA,
|
|
Class: dnsmessage.ClassINET,
|
|
},
|
|
Body: &dnsmessage.AResource{A: [4]byte(c.ipv4)},
|
|
})
|
|
}
|
|
}
|
|
|
|
func (c *testDnsConn) appendIPV6(msg *dnsmessage.Message) {
|
|
if c.ipv6 != nil {
|
|
msg.Answers = append(msg.Answers, dnsmessage.Resource{
|
|
Header: dnsmessage.ResourceHeader{
|
|
Name: dnsmessage.MustNewName(c.wantName),
|
|
Type: dnsmessage.TypeAAAA,
|
|
Class: dnsmessage.ClassINET,
|
|
},
|
|
Body: &dnsmessage.AAAAResource{AAAA: [16]byte(c.ipv6)},
|
|
})
|
|
}
|
|
}
|
|
|
|
// RemoteAddr implements net.Conn.
|
|
func (*testDnsConn) RemoteAddr() net.Addr {
|
|
panic("unimplemented")
|
|
}
|
|
|
|
// SetDeadline implements net.Conn.
|
|
func (*testDnsConn) SetDeadline(t time.Time) error {
|
|
return nil
|
|
}
|
|
|
|
// SetReadDeadline implements net.Conn.
|
|
func (*testDnsConn) SetReadDeadline(t time.Time) error {
|
|
return nil
|
|
}
|
|
|
|
// SetWriteDeadline implements net.Conn.
|
|
func (*testDnsConn) SetWriteDeadline(t time.Time) error {
|
|
return nil
|
|
}
|
|
|
|
// Write implements net.Conn.
|
|
func (c *testDnsConn) Write(b []byte) (n int, err error) {
|
|
var p dnsmessage.Parser
|
|
var h dnsmessage.Header
|
|
if h, err = p.Start(b[2:]); err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
c.questions, err = p.AllQuestions()
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
for _, q := range c.questions {
|
|
qStr := q.Name.String()
|
|
if qStr != c.wantName {
|
|
continue
|
|
}
|
|
|
|
c.requested = true
|
|
c.reqID = h.ID
|
|
if err := p.SkipAllQuestions(); err != nil {
|
|
return 0, err
|
|
}
|
|
break
|
|
}
|
|
|
|
return len(b), nil
|
|
}
|
|
|
|
// testConnStub is the net.Conn returned for the dial to the resolved address.
// The tests in this file only Close it. Every other method panics, so any
// unexpected use fails loudly.
type testConnStub struct{}

// Close implements net.Conn. It always succeeds.
func (*testConnStub) Close() error { return nil }

// LocalAddr implements net.Conn. Not supported by the stub.
func (*testConnStub) LocalAddr() net.Addr { panic("unimplemented") }

// Read implements net.Conn. Not supported by the stub.
func (*testConnStub) Read([]byte) (int, error) { panic("unimplemented") }

// RemoteAddr implements net.Conn. Not supported by the stub.
func (*testConnStub) RemoteAddr() net.Addr { panic("unimplemented") }

// SetDeadline implements net.Conn. Not supported by the stub.
func (*testConnStub) SetDeadline(time.Time) error { panic("unimplemented") }

// SetReadDeadline implements net.Conn. Not supported by the stub.
func (*testConnStub) SetReadDeadline(time.Time) error { panic("unimplemented") }

// SetWriteDeadline implements net.Conn. Not supported by the stub.
func (*testConnStub) SetWriteDeadline(time.Time) error { panic("unimplemented") }

// Write implements net.Conn. Not supported by the stub.
func (*testConnStub) Write([]byte) (int, error) { panic("unimplemented") }
|