multinet/dialer.go

220 lines
5.1 KiB
Go
Raw Normal View History

package multinet
import (
"bytes"
"context"
"fmt"
"net"
"net/netip"
"sort"
"sync"
)
// Dialer contains the single most important method from the net.Dialer.
// *net.Dialer satisfies it, so a plain net.Dialer can be used wherever a
// Dialer is expected.
type Dialer interface {
	// DialContext connects to address on the named network using ctx.
	DialContext(ctx context.Context, network, address string) (net.Conn, error)
}
// Multidialer is like Dialer, but supports link state updates.
type Multidialer interface {
	Dialer
	// UpdateInterface reports a link state change: up == true marks the
	// affected sources as usable again, up == false marks them down. Which
	// sources are affected (by interface name and/or by addr) is defined by
	// the implementation.
	UpdateInterface(name string, addr netip.Addr, up bool) error
}
// Compile-time checks that the concrete types satisfy the package interfaces.
var (
	_ Dialer      = (*net.Dialer)(nil)
	_ Dialer      = (*dialer)(nil)
	_ Multidialer = (*dialer)(nil)
)
// dialer is the Multidialer implementation returned by NewDialer.
type dialer struct {
	// Protects subnets field (recursively), including the Down flag of every
	// source, which UpdateInterface writes under the write lock.
	mtx     sync.RWMutex
	subnets []Subnet
	// Default options for the net.Dialer. DialContext also uses it directly
	// for destinations outside every configured subnet when restrict is false.
	dialer net.Dialer
	// If true, allow to dial only configured subnets.
	restrict bool
	// Algorithm which picks the source address for each destination IP.
	balancer balancer
	// Hook used in tests, overrides `net.Dialer.DialContext()`.
	testHookDialContext func(d *net.Dialer, ctx context.Context, network, address string) (net.Conn, error)
}
// Subnet represents a single subnet, possibly routable from multiple interfaces.
type Subnet struct {
	// Mask is the subnet prefix; DialContext routes destinations contained in
	// it through one of Interfaces.
	Mask netip.Prefix
	// Interfaces are the local source addresses that lie inside Mask. When
	// built by NewDialer they are ordered by interface name, then by IP bytes.
	Interfaces []Source
}
// Source represents a single source IP belonging to a particular subnet.
type Source struct {
	// Name is the name of the network interface that owns the address.
	Name string
	// LocalAddr is the source address; NewDialer sets only its IP.
	LocalAddr *net.TCPAddr
	// Down is true when the source has been marked unavailable through
	// UpdateInterface.
	Down bool
}
// Config contains Multidialer configuration.
type Config struct {
	// Routable subnets to prioritize, in CIDR format (e.g. "10.0.0.0/8").
	Subnets []string
	// If true, only the configured subnets are reachable through this dialer.
	// Otherwise, other destinations fall back to a regular net.Dialer dial
	// without source address selection.
	Restrict bool
	// Dialer contains default options for the net.Dialer to use.
	// LocalAddr is overridden.
	Dialer net.Dialer
	// Balancer specifies the algorithm used to pick the source address.
	Balancer BalancerType
}
// NewDialer builds a Multidialer for the given configuration.
//
// It lists the host's network interfaces (sorted by name) together with their
// addresses. For every CIDR in c.Subnets, each interface address inside that
// subnet is recorded as a candidate source. c.Dialer is kept as the dialer's
// default options, c.Restrict controls whether destinations outside the
// configured subnets may be dialed, and c.Balancer selects the source-picking
// algorithm.
//
// It returns an error in any of these cases:
//   - the interfaces or their addresses cannot be listed;
//   - an entry of c.Subnets is not a valid CIDR prefix;
//   - c.Balancer is not a supported BalancerType.
func NewDialer(c Config) (Multidialer, error) {
	ifaces, err := net.Interfaces()
	if err != nil {
		return nil, err
	}
	sort.Slice(ifaces, func(i, j int) bool {
		return ifaces[i].Name < ifaces[j].Name
	})

	sources := make([]iface, 0, len(ifaces))
	for i := range ifaces {
		info, err := processIface(ifaces[i])
		if err != nil {
			return nil, err
		}
		sources = append(sources, info)
	}

	d := &dialer{
		// Keep the caller's dial options (Timeout, KeepAlive, Control, ...).
		// DialContext uses them for the non-subnet fallback, and they are
		// presumably the base for balancer dials as well.
		dialer:   c.Dialer,
		restrict: c.Restrict,
	}
	for _, subnet := range c.Subnets {
		s, err := processSubnet(subnet, sources)
		if err != nil {
			return nil, err
		}
		d.subnets = append(d.subnets, s)
	}

	switch c.Balancer {
	case BalancerTypeNoop:
		d.balancer = &firstEnabled{d: d}
	case BalancerTypeRoundRobin:
		d.balancer = &roundRobin{d: d}
	default:
		return nil, fmt.Errorf("invalid balancer type: %s", c.Balancer)
	}
	return d, nil
}
// iface is a snapshot of a host network interface and the addresses that were
// assigned to it when NewDialer ran.
type iface struct {
	info  net.Interface
	addrs []netip.Prefix
}
// processIface captures info together with its assigned addresses, converted
// to netip.Prefix values.
//
// It fails if the interface's addresses cannot be listed, or if the textual
// form of any of them is not accepted by netip.ParsePrefix.
func processIface(info net.Interface) (iface, error) {
	addrList, err := info.Addrs()
	if err != nil {
		return iface{}, err
	}

	var prefixes []netip.Prefix
	for _, a := range addrList {
		prefix, parseErr := netip.ParsePrefix(a.String())
		if parseErr != nil {
			return iface{}, parseErr
		}
		prefixes = append(prefixes, prefix)
	}
	return iface{info: info, addrs: prefixes}, nil
}
// processSubnet parses subnet as a CIDR prefix and collects, from sources,
// every interface address that lies inside it.
//
// Each match becomes a Source whose LocalAddr carries only the IP. The result
// is ordered by interface name, then by the raw bytes of the IP. A malformed
// subnet yields the error from netip.ParsePrefix.
func processSubnet(subnet string, sources []iface) (Subnet, error) {
	prefix, err := netip.ParsePrefix(subnet)
	if err != nil {
		return Subnet{}, err
	}

	var matched []Source
	for i := range sources {
		src := &sources[i]
		for _, p := range src.addrs {
			ip := p.Addr()
			if !prefix.Contains(ip) {
				continue
			}
			matched = append(matched, Source{
				Name:      src.info.Name,
				LocalAddr: &net.TCPAddr{IP: net.IP(ip.AsSlice())},
			})
		}
	}

	sort.Slice(matched, func(a, b int) bool {
		x, y := matched[a], matched[b]
		if x.Name == y.Name {
			return bytes.Compare(x.LocalAddr.IP, y.LocalAddr.IP) < 0
		}
		return x.Name < y.Name
	})

	return Subnet{Mask: prefix, Interfaces: matched}, nil
}
// DialContext implements the Dialer interface.
// Hostnames for address are currently not supported: address must be a
// literal "ip:port" (IPv6 in brackets), otherwise the error from
// netip.ParseAddrPort is returned.
//
// The first configured subnet (in configuration order) that contains the
// destination IP hands the dial to the balancer. When no subnet matches, the
// dial is rejected if restrict is set; otherwise it goes through d.dialer.
//
// NOTE(review): the read lock is held for the entire dial, including the
// balancer's connection attempt. A slow dial therefore delays UpdateInterface,
// and while that writer waits, every new DialContext blocks as well. Confirm
// whether the balancer needs the lock beyond selecting a source.
func (d *dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
	ap, err := netip.ParseAddrPort(address)
	if err != nil {
		return nil, err
	}
	target := ap.Addr()

	d.mtx.RLock()
	defer d.mtx.RUnlock()

	for i := range d.subnets {
		subnet := &d.subnets[i]
		if !subnet.Mask.Contains(target) {
			continue
		}
		return d.balancer.DialContext(ctx, subnet, network, address)
	}

	if !d.restrict {
		return d.dialer.DialContext(ctx, network, address)
	}
	return nil, fmt.Errorf("no suitable interface for: [%s]%s", network, address)
}
// dialContext performs the connection attempt on behalf of a balancer using
// the dial options in nd (presumably with the chosen source address already in
// nd.LocalAddr; that is set by the caller, not here).
//
// network is forwarded as requested rather than being forced to "tcp". This
// way "tcp4"/"tcp6" keep their address family and "udp" really dials UDP.
// If testHookDialContext is set, it is called instead of nd.DialContext.
func (d *dialer) dialContext(nd *net.Dialer, ctx context.Context, network, address string) (net.Conn, error) {
	nd = dialerForNetwork(nd, network)
	if h := d.testHookDialContext; h != nil {
		return h(nd, ctx, network, address)
	}
	return nd.DialContext(ctx, network, address)
}

// dialerForNetwork returns nd unchanged unless network is a UDP network and
// nd.LocalAddr is a *net.TCPAddr. In that case it returns a copy of nd whose
// LocalAddr is the equivalent *net.UDPAddr.
//
// Sources store their address as *net.TCPAddr. net.Dialer rejects a local
// address whose type does not match the network ("mismatched local address
// type"). The caller's dialer is never mutated.
func dialerForNetwork(nd *net.Dialer, network string) *net.Dialer {
	if nd == nil {
		return nd
	}
	switch network {
	case "udp", "udp4", "udp6":
	default:
		return nd
	}
	tcpAddr, ok := nd.LocalAddr.(*net.TCPAddr)
	if !ok || tcpAddr == nil {
		return nd
	}
	adjusted := *nd
	adjusted.LocalAddr = &net.UDPAddr{IP: tcpAddr.IP, Port: tcpAddr.Port, Zone: tcpAddr.Zone}
	return &adjusted
}
// UpdateInterface implements the Multidialer interface.
// Updating address on a specific interface is currently not supported.
//
// Every source in every subnet gets Down = !up when any of these holds:
//   - its interface name equals iface;
//   - its IP equals addr;
//   - its IP is the unspecified address (0.0.0.0 or ::).
//
// Other sources are left unchanged. It always returns nil.
func (d *dialer) UpdateInterface(iface string, addr netip.Addr, up bool) error {
	d.mtx.Lock()
	defer d.mtx.Unlock()

	down := !up
	for i := range d.subnets {
		srcs := d.subnets[i].Interfaces
		for j := range srcs {
			src := &srcs[j]
			if src.Name != iface {
				// A malformed IP leaves local as the zero Addr, which is
				// neither unspecified nor equal to any valid addr.
				local, _ := netip.AddrFromSlice(src.LocalAddr.IP)
				if !local.IsUnspecified() && local != addr {
					continue
				}
			}
			src.Down = down
		}
	}
	return nil
}