package multinet import ( "context" "net" "testing" "time" "github.com/stretchr/testify/require" "golang.org/x/net/dns/dnsmessage" ) func TestHostnameResolveIPv4(t *testing.T) { resolvedAddr := "10.11.12.180:8080" resolved := false d, err := NewDialer(Config{ Subnets: []string{"10.11.12.0/24"}, InterfaceSource: testInterfacesV4, DialContext: func(d *net.Dialer, ctx context.Context, network, address string) (net.Conn, error) { if resolvedAddr == address { resolved = true return &testConnStub{}, nil } if network == "udp" { return &testDnsConn{ wantName: "s01.storage.devnev.", ipv4: []byte{10, 11, 12, 180}, }, nil } panic("unexpected call") }, }) require.NoError(t, err) conn, err := d.DialContext(context.Background(), "tcp", "s01.storage.devnev:8080") require.NoError(t, err) require.NoError(t, conn.Close()) require.True(t, resolved) } func TestHostnameResolveIPv6(t *testing.T) { resolvedAddr := "[2001:db8:85a3:8d3:1319:8a2e:370:8195]:8080" ipv6 := net.ParseIP("2001:db8:85a3:8d3:1319:8a2e:370:8195") resolved := false d, err := NewDialer(Config{ Subnets: []string{"2001:db8:85a3:8d3::/64"}, InterfaceSource: testInterfacesV6, DialContext: func(d *net.Dialer, ctx context.Context, network, address string) (net.Conn, error) { if resolvedAddr == address { resolved = true return &testConnStub{}, nil } if network == "udp" { return &testDnsConn{ wantName: "s01.storage.devnev.", ipv6: ipv6, }, nil } panic("unexpected call") }, }) require.NoError(t, err) conn, err := d.DialContext(context.Background(), "tcp", "s01.storage.devnev:8080") require.NoError(t, err) require.NoError(t, conn.Close()) require.True(t, resolved) } func testInterfacesV4() ([]Interface, error) { return []Interface{ &testInterface{ name: "data1", addrs: []net.Addr{ &testAddr{ network: "tcp", str: "10.11.12.101/24", }, }, }, &testInterface{ name: "data2", addrs: []net.Addr{ &testAddr{ network: "tcp", str: "10.11.12.102/24", }, }, }, }, nil } func testInterfacesV6() ([]Interface, error) { return []Interface{ &testInterface{ name: "data1", addrs: []net.Addr{ &testAddr{ network: "tcp", str: "2001:db8:85a3:8d3:1319:8a2e:370:7348/64", }, }, }, &testInterface{ name: "data2", addrs: []net.Addr{ &testAddr{ network: "tcp", str: "2001:db8:85a3:8d3:1319:8a2e:370:8192/64", }, }, }, }, nil } type testInterface struct { name string addrs []net.Addr down bool } func (i *testInterface) Name() string { return i.name } func (i *testInterface) Addrs() ([]net.Addr, error) { return i.addrs, nil } func (i *testInterface) Down() bool { return i.down } type testAddr struct { network string str string } func (a *testAddr) Network() string { return a.network } func (a *testAddr) String() string { return a.str } type testDnsConn struct { wantName string ipv4 []byte ipv6 []byte reqID uint16 requested bool secondRead bool questions []dnsmessage.Question } // Close implements net.Conn. func (*testDnsConn) Close() error { return nil } // LocalAddr implements net.Conn. func (*testDnsConn) LocalAddr() net.Addr { panic("unimplemented") } // Read implements net.Conn. func (c *testDnsConn) Read(b []byte) (n int, err error) { if c.secondRead { data, err := c.data() if err != nil { return 0, err } copy(b, data) return len(data), nil } data, err := c.data() if err != nil { return 0, err } l := len(data) b[0] = byte(l >> 8) b[1] = byte(l) c.secondRead = true return 2, nil } func (c *testDnsConn) data() ([]byte, error) { if !c.requested { builder := dnsmessage.NewBuilder(nil, dnsmessage.Header{RCode: dnsmessage.RCodeNameError}) res, err := builder.Finish() if err != nil { return nil, err } return res, nil } msg := dnsmessage.Message{ Header: dnsmessage.Header{Response: true, Authoritative: true, RCode: dnsmessage.RCodeSuccess, ID: c.reqID}, Questions: c.questions, Answers: []dnsmessage.Resource{}, } c.appendIPv4(&msg) c.appendIPV6(&msg) return msg.Pack() } func (c *testDnsConn) appendIPv4(msg *dnsmessage.Message) { if c.ipv4 != nil { msg.Answers = append(msg.Answers, dnsmessage.Resource{ Header: dnsmessage.ResourceHeader{ Name: dnsmessage.MustNewName(c.wantName), Type: dnsmessage.TypeA, Class: dnsmessage.ClassINET, }, Body: &dnsmessage.AResource{A: [4]byte(c.ipv4)}, }) } } func (c *testDnsConn) appendIPV6(msg *dnsmessage.Message) { if c.ipv6 != nil { msg.Answers = append(msg.Answers, dnsmessage.Resource{ Header: dnsmessage.ResourceHeader{ Name: dnsmessage.MustNewName(c.wantName), Type: dnsmessage.TypeAAAA, Class: dnsmessage.ClassINET, }, Body: &dnsmessage.AAAAResource{AAAA: [16]byte(c.ipv6)}, }) } } // RemoteAddr implements net.Conn. func (*testDnsConn) RemoteAddr() net.Addr { panic("unimplemented") } // SetDeadline implements net.Conn. func (*testDnsConn) SetDeadline(t time.Time) error { return nil } // SetReadDeadline implements net.Conn. func (*testDnsConn) SetReadDeadline(t time.Time) error { return nil } // SetWriteDeadline implements net.Conn. func (*testDnsConn) SetWriteDeadline(t time.Time) error { return nil } // Write implements net.Conn. func (c *testDnsConn) Write(b []byte) (n int, err error) { var p dnsmessage.Parser var h dnsmessage.Header if h, err = p.Start(b[2:]); err != nil { return 0, err } c.questions, err = p.AllQuestions() if err != nil { return 0, err } for _, q := range c.questions { qStr := q.Name.String() if qStr != c.wantName { continue } c.requested = true c.reqID = h.ID if err := p.SkipAllQuestions(); err != nil { return 0, err } break } return len(b), nil } type testConnStub struct{} // Close implements net.Conn. func (*testConnStub) Close() error { return nil } // LocalAddr implements net.Conn. func (*testConnStub) LocalAddr() net.Addr { panic("unimplemented") } // Read implements net.Conn. func (*testConnStub) Read(b []byte) (n int, err error) { panic("unimplemented") } // RemoteAddr implements net.Conn. func (*testConnStub) RemoteAddr() net.Addr { panic("unimplemented") } // SetDeadline implements net.Conn. func (*testConnStub) SetDeadline(t time.Time) error { panic("unimplemented") } // SetReadDeadline implements net.Conn. func (*testConnStub) SetReadDeadline(t time.Time) error { panic("unimplemented") } // SetWriteDeadline implements net.Conn. func (*testConnStub) SetWriteDeadline(t time.Time) error { panic("unimplemented") } // Write implements net.Conn. func (*testConnStub) Write(b []byte) (n int, err error) { panic("unimplemented") }