multinet/dialer_hostname_test.go
Dmitrii Stepanov 4a46c8c008
[#11] dialer: Use source IPs instead of interfaces
The use of network interfaces does not cover cases where it is necessary
to use network interfaces to access different subnets.

Signed-off-by: Dmitrii Stepanov <d.stepanov@yadro.com>
2024-10-09 16:22:45 +03:00

261 lines
5.8 KiB
Go

package multinet
import (
	"context"
	"errors"
	"io"
	"net"
	"net/netip"
	"testing"
	"time"

	"github.com/stretchr/testify/require"
	"golang.org/x/net/dns/dnsmessage"
)
// TestHostnameResolveIPv4 checks that dialing "s01.storage.devnev:8080"
// resolves the name through the configured DialContext (the "udp" dial is
// answered by a testDnsConn carrying an A record) and then dials the
// resolved IPv4 address.
func TestHostnameResolveIPv4(t *testing.T) {
	resolvedAddr := "10.11.12.180:8080"
	resolved := false
	d, err := NewDialer(Config{
		Subnets: []Subnet{
			{
				Prefix: netip.MustParsePrefix("10.11.12.0/24"),
				SourceIPs: []netip.Addr{
					netip.MustParseAddr("10.11.12.101"),
					netip.MustParseAddr("10.11.12.102"),
				},
			},
		},
		DialContext: func(d *net.Dialer, ctx context.Context, network, address string) (net.Conn, error) {
			if resolvedAddr == address {
				resolved = true
				return &testConnStub{}, nil
			}
			if network == "udp" {
				return &testDnsConn{
					wantName: "s01.storage.devnev.",
					ipv4:     []byte{10, 11, 12, 180},
				}, nil
			}
			// This callback may be invoked off the test goroutine (e.g. while
			// resolving the name), where a panic would abort the whole test
			// binary. t.Errorf is safe from any goroutine; fail the dial too.
			t.Errorf("unexpected dial: network=%q address=%q", network, address)
			return nil, &net.AddrError{Err: "unexpected dial on network " + network, Addr: address}
		},
	})
	require.NoError(t, err)
	conn, err := d.DialContext(context.Background(), "tcp", "s01.storage.devnev:8080")
	require.NoError(t, err)
	require.NoError(t, conn.Close())
	require.True(t, resolved, "expected a dial to the resolved address %s", resolvedAddr)
}
// TestHostnameResolveIPv6 checks that dialing "s01.storage.devnev:8080"
// resolves the name through the configured DialContext (the "udp" dial is
// answered by a testDnsConn carrying an AAAA record) and then dials the
// resolved IPv6 address.
func TestHostnameResolveIPv6(t *testing.T) {
	resolvedAddr := "[2001:db8:85a3:8d3:1319:8a2e:370:8195]:8080"
	ipv6 := net.ParseIP("2001:db8:85a3:8d3:1319:8a2e:370:8195")
	resolved := false
	d, err := NewDialer(Config{
		Subnets: []Subnet{
			{
				Prefix: netip.MustParsePrefix("2001:db8:85a3:8d3::/64"),
				SourceIPs: []netip.Addr{
					netip.MustParseAddr("2001:db8:85a3:8d3:1319:8a2e:370:7348"),
					netip.MustParseAddr("2001:db8:85a3:8d3:1319:8a2e:370:8192"),
				},
			},
		},
		DialContext: func(d *net.Dialer, ctx context.Context, network, address string) (net.Conn, error) {
			if resolvedAddr == address {
				resolved = true
				return &testConnStub{}, nil
			}
			if network == "udp" {
				return &testDnsConn{
					wantName: "s01.storage.devnev.",
					ipv6:     ipv6,
				}, nil
			}
			// This callback may be invoked off the test goroutine (e.g. while
			// resolving the name), where a panic would abort the whole test
			// binary. t.Errorf is safe from any goroutine; fail the dial too.
			t.Errorf("unexpected dial: network=%q address=%q", network, address)
			return nil, &net.AddrError{Err: "unexpected dial on network " + network, Addr: address}
		},
	})
	require.NoError(t, err)
	conn, err := d.DialContext(context.Background(), "tcp", "s01.storage.devnev:8080")
	require.NoError(t, err)
	require.NoError(t, conn.Close())
	require.True(t, resolved, "expected a dial to the resolved address %s", resolvedAddr)
}
// testDnsConn is a fake net.Conn that answers a single DNS query. Queries
// and replies use a 2-byte length prefix in front of the DNS message (see
// Write and Read). A query for wantName is answered with an A record from
// ipv4 and/or an AAAA record from ipv6.
type testDnsConn struct {
	// wantName is the query name to answer. It is compared with
	// dnsmessage.Name.String(); the tests pass the fully-qualified form
	// with a trailing dot.
	wantName string
	// ipv4 and ipv6 are the record payloads; a nil slice means the reply
	// carries no record of that type.
	ipv4, ipv6 []byte

	// reqID is the ID of the matching query, echoed back in the reply.
	reqID uint16
	// requested is set by Write once a question for wantName was seen.
	requested bool
	// secondRead is set by Read after it has returned the length prefix.
	secondRead bool
	// questions holds the questions of the last query, echoed in the reply.
	questions []dnsmessage.Question
}
// Close implements net.Conn. The fake holds no resources, so closing it
// always succeeds.
func (c *testDnsConn) Close() error { return nil }

// LocalAddr implements net.Conn. The fake has no local address; calling
// this panics.
func (c *testDnsConn) LocalAddr() net.Addr { panic("unimplemented") }
// Read implements net.Conn.
//
// The reply is delivered the way a length-prefixed DNS stream would carry
// it. The first call returns only the 2-byte big-endian length of the
// message. Every later call returns the whole message, produced by data().
// The message is rebuilt on each call and always copied from its start, so
// partial reads cannot be resumed. A buffer too small for the requested
// piece gets io.ErrShortBuffer instead of a silently truncated reply.
func (c *testDnsConn) Read(b []byte) (n int, err error) {
	data, err := c.data()
	if err != nil {
		return 0, err
	}
	if c.secondRead {
		if len(b) < len(data) {
			return 0, io.ErrShortBuffer
		}
		// Report the bytes actually copied: io.Reader requires n <= len(b).
		return copy(b, data), nil
	}
	if len(b) < 2 {
		return 0, io.ErrShortBuffer
	}
	l := len(data)
	b[0] = byte(l >> 8)
	b[1] = byte(l)
	c.secondRead = true
	return 2, nil
}
// data packs the DNS message that Read returns.
//
// After Write has seen a question for wantName, data returns an
// authoritative success reply. The reply echoes the query ID and questions
// and carries the configured A and/or AAAA records. Before that, it returns
// a bare header with RCODE NameError: ID 0, Response flag unset, no sections.
//
// NOTE(review): without the Response flag and a matching ID, a resolver is
// likely to reject that bare header as an invalid reply rather than treat it
// as NXDOMAIN. Confirm whether this is intended.
func (c *testDnsConn) data() ([]byte, error) {
	if c.requested {
		reply := dnsmessage.Message{
			Header: dnsmessage.Header{
				ID:            c.reqID,
				Response:      true,
				Authoritative: true,
				RCode:         dnsmessage.RCodeSuccess,
			},
			Questions: c.questions,
			Answers:   []dnsmessage.Resource{},
		}
		c.appendIPv4(&reply)
		c.appendIPV6(&reply)
		return reply.Pack()
	}

	nameErr := dnsmessage.NewBuilder(nil, dnsmessage.Header{RCode: dnsmessage.RCodeNameError})
	return nameErr.Finish()
}
// appendIPv4 adds an A record for wantName to msg when the conn was
// configured with an IPv4 address (c.ipv4 != nil).
//
// c.ipv4 may be 4 raw bytes or the 16-byte IPv4-mapped form returned by
// net.ParseIP. Any other value is a test setup error and panics with a
// descriptive message.
func (c *testDnsConn) appendIPv4(msg *dnsmessage.Message) {
	if c.ipv4 == nil {
		return
	}
	// Normalize first: converting a 16-byte slice directly to [4]byte
	// would panic.
	ip4 := net.IP(c.ipv4).To4()
	if ip4 == nil {
		panic("testDnsConn: ipv4 is not an IPv4 address")
	}
	msg.Answers = append(msg.Answers, dnsmessage.Resource{
		Header: dnsmessage.ResourceHeader{
			Name:  dnsmessage.MustNewName(c.wantName),
			Type:  dnsmessage.TypeA,
			Class: dnsmessage.ClassINET,
		},
		Body: &dnsmessage.AResource{A: [4]byte(ip4)},
	})
}
// appendIPV6 adds an AAAA record for wantName to msg when the conn was
// configured with an IPv6 address (c.ipv6 != nil). c.ipv6 must be exactly
// 16 bytes long; otherwise the array conversion panics.
func (c *testDnsConn) appendIPV6(msg *dnsmessage.Message) {
	if c.ipv6 == nil {
		return
	}
	header := dnsmessage.ResourceHeader{
		Name:  dnsmessage.MustNewName(c.wantName),
		Type:  dnsmessage.TypeAAAA,
		Class: dnsmessage.ClassINET,
	}
	body := &dnsmessage.AAAAResource{AAAA: [16]byte(c.ipv6)}
	msg.Answers = append(msg.Answers, dnsmessage.Resource{Header: header, Body: body})
}
// RemoteAddr implements net.Conn. The fake has no peer address; calling
// this panics.
func (c *testDnsConn) RemoteAddr() net.Addr { panic("unimplemented") }

// SetDeadline implements net.Conn. Deadlines are accepted and ignored.
func (c *testDnsConn) SetDeadline(_ time.Time) error { return nil }

// SetReadDeadline implements net.Conn. Deadlines are accepted and ignored.
func (c *testDnsConn) SetReadDeadline(_ time.Time) error { return nil }

// SetWriteDeadline implements net.Conn. Deadlines are accepted and ignored.
func (c *testDnsConn) SetWriteDeadline(_ time.Time) error { return nil }
// Write implements net.Conn.
//
// b must hold one DNS query preceded by a 2-byte length prefix. The prefix
// is skipped, not validated. Write records the query's questions so data()
// can echo them. If any question asks for wantName, it also stores the
// query ID and marks the query as answerable. On success it reports all of
// b as written; a short frame or a parse failure is returned as an error.
func (c *testDnsConn) Write(b []byte) (n int, err error) {
	if len(b) < 2 {
		return 0, errors.New("testDnsConn: write shorter than the 2-byte length prefix")
	}

	var p dnsmessage.Parser
	h, err := p.Start(b[2:])
	if err != nil {
		return 0, err
	}
	if c.questions, err = p.AllQuestions(); err != nil {
		return 0, err
	}

	// AllQuestions has already consumed the question section, so no
	// further skipping is needed before returning.
	for _, q := range c.questions {
		if q.Name.String() == c.wantName {
			c.requested = true
			c.reqID = h.ID
			break
		}
	}
	return len(b), nil
}
// testConnStub is the placeholder net.Conn that the tests' DialContext
// returns for the resolved address. Only Close is usable. Every other method
// panics with "unimplemented", so any unexpected use of the connection fails
// loudly.
type testConnStub struct{}

// Compile-time check that the stub satisfies net.Conn.
var _ net.Conn = (*testConnStub)(nil)

// Close implements net.Conn; it always succeeds.
func (*testConnStub) Close() error { return nil }

// Read implements net.Conn; not supported by the stub.
func (*testConnStub) Read(_ []byte) (int, error) { panic("unimplemented") }

// Write implements net.Conn; not supported by the stub.
func (*testConnStub) Write(_ []byte) (int, error) { panic("unimplemented") }

// LocalAddr implements net.Conn; not supported by the stub.
func (*testConnStub) LocalAddr() net.Addr { panic("unimplemented") }

// RemoteAddr implements net.Conn; not supported by the stub.
func (*testConnStub) RemoteAddr() net.Addr { panic("unimplemented") }

// SetDeadline implements net.Conn; not supported by the stub.
func (*testConnStub) SetDeadline(_ time.Time) error { panic("unimplemented") }

// SetReadDeadline implements net.Conn; not supported by the stub.
func (*testConnStub) SetReadDeadline(_ time.Time) error { panic("unimplemented") }

// SetWriteDeadline implements net.Conn; not supported by the stub.
func (*testConnStub) SetWriteDeadline(_ time.Time) error { panic("unimplemented") }