forked from TrueCloudLab/multinet
220 lines
5.1 KiB
Go
220 lines
5.1 KiB
Go
|
package multinet
|
||
|
|
||
|
import (
|
||
|
"bytes"
|
||
|
"context"
|
||
|
"fmt"
|
||
|
"net"
|
||
|
"net/netip"
|
||
|
"sort"
|
||
|
"sync"
|
||
|
)
|
||
|
|
||
|
// Dialer is the one method of *net.Dialer that this package relies on.
// Both *net.Dialer and this package's Multidialer implementation satisfy it.
type Dialer interface {
	DialContext(ctx context.Context, network, address string) (net.Conn, error)
}
|
||
|
|
||
|
// Multidialer is like Dialer, but supports link state updates.
|
||
|
type Multidialer interface {
|
||
|
Dialer
|
||
|
UpdateInterface(name string, addr netip.Addr, status bool) error
|
||
|
}
|
||
|
|
||
|
var (
|
||
|
_ Dialer = (*net.Dialer)(nil)
|
||
|
_ Dialer = (*dialer)(nil)
|
||
|
)
|
||
|
|
||
|
type dialer struct {
|
||
|
// Protects subnets field (recursively).
|
||
|
mtx sync.RWMutex
|
||
|
subnets []Subnet
|
||
|
// Default options for the net.Dialer.
|
||
|
dialer net.Dialer
|
||
|
// If true, allow to dial only configured subnets.
|
||
|
restrict bool
|
||
|
// Algorithm to which picks source address for each ip.
|
||
|
balancer balancer
|
||
|
// Hook used in tests, overrides `net.Dialer.DialContext()`
|
||
|
testHookDialContext func(d *net.Dialer, ctx context.Context, network, address string) (net.Conn, error)
|
||
|
}
|
||
|
|
||
|
// Subnet represents a single subnet, possibly routable from multiple interfaces.
|
||
|
type Subnet struct {
|
||
|
Mask netip.Prefix
|
||
|
Interfaces []Source
|
||
|
}
|
||
|
|
||
|
// Source is one local IP address, found on a named interface, that belongs to
// a particular Subnet.
type Source struct {
	// Name is the name of the interface that owns the address.
	Name string
	// LocalAddr holds the source IP (the port is left zero).
	// NOTE(review): presumably installed as net.Dialer.LocalAddr by the balancer — confirm.
	LocalAddr *net.TCPAddr
	// Down marks the source as unavailable; UpdateInterface sets and clears it.
	Down bool
}
|
||
|
|
||
|
// Config contains Multidialer configuration.
|
||
|
type Config struct {
|
||
|
// Routable subnets to prioritize in CIDR format.
|
||
|
Subnets []string
|
||
|
// If true, the only configurd subnets available through this dialer.
|
||
|
// Otherwise, a failback to the net.DefaultDialer.
|
||
|
Restrict bool
|
||
|
// Dialer containes default options for the net.Dialer to use.
|
||
|
// LocalAddr is overriden.
|
||
|
Dialer net.Dialer
|
||
|
// Balancer specifies algorithm used to pick source address.
|
||
|
Balancer BalancerType
|
||
|
}
|
||
|
|
||
|
// NewDialer ...
|
||
|
func NewDialer(c Config) (Multidialer, error) {
|
||
|
ifaces, err := net.Interfaces()
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
sort.Slice(ifaces, func(i, j int) bool {
|
||
|
return ifaces[i].Name < ifaces[j].Name
|
||
|
})
|
||
|
|
||
|
var sources []iface
|
||
|
for i := range ifaces {
|
||
|
info, err := processIface(ifaces[i])
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
sources = append(sources, info)
|
||
|
}
|
||
|
|
||
|
var d dialer
|
||
|
for _, subnet := range c.Subnets {
|
||
|
s, err := processSubnet(subnet, sources)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
d.subnets = append(d.subnets, s)
|
||
|
}
|
||
|
|
||
|
switch c.Balancer {
|
||
|
case BalancerTypeNoop:
|
||
|
d.balancer = &firstEnabled{d: &d}
|
||
|
case BalancerTypeRoundRobin:
|
||
|
d.balancer = &roundRobin{d: &d}
|
||
|
default:
|
||
|
return nil, fmt.Errorf("invalid balancer type: %s", c.Balancer)
|
||
|
}
|
||
|
|
||
|
d.restrict = c.Restrict
|
||
|
|
||
|
return &d, nil
|
||
|
}
|
||
|
|
||
|
// iface pairs a host network interface with the address prefixes assigned to it.
type iface struct {
	info  net.Interface  // the interface as reported by the net package
	addrs []netip.Prefix // parsed from info.Addrs(); host bits are kept
}
|
||
|
|
||
|
func processIface(info net.Interface) (iface, error) {
|
||
|
ips, err := info.Addrs()
|
||
|
if err != nil {
|
||
|
return iface{}, err
|
||
|
}
|
||
|
|
||
|
var addrs []netip.Prefix
|
||
|
for i := range ips {
|
||
|
p, err := netip.ParsePrefix(ips[i].String())
|
||
|
if err != nil {
|
||
|
return iface{}, err
|
||
|
}
|
||
|
|
||
|
addrs = append(addrs, p)
|
||
|
}
|
||
|
return iface{info: info, addrs: addrs}, nil
|
||
|
}
|
||
|
|
||
|
func processSubnet(subnet string, sources []iface) (Subnet, error) {
|
||
|
s, err := netip.ParsePrefix(subnet)
|
||
|
if err != nil {
|
||
|
return Subnet{}, err
|
||
|
}
|
||
|
|
||
|
var ifs []Source
|
||
|
for _, source := range sources {
|
||
|
for i := range source.addrs {
|
||
|
src := source.addrs[i].Addr()
|
||
|
if s.Contains(src) {
|
||
|
ifs = append(ifs, Source{
|
||
|
Name: source.info.Name,
|
||
|
LocalAddr: &net.TCPAddr{IP: net.IP(src.AsSlice())},
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
sort.Slice(ifs, func(i, j int) bool {
|
||
|
if ifs[i].Name != ifs[j].Name {
|
||
|
return ifs[i].Name < ifs[j].Name
|
||
|
}
|
||
|
return bytes.Compare(ifs[i].LocalAddr.IP, ifs[j].LocalAddr.IP) == -1
|
||
|
})
|
||
|
|
||
|
return Subnet{
|
||
|
Mask: s,
|
||
|
Interfaces: ifs,
|
||
|
}, nil
|
||
|
}
|
||
|
|
||
|
// DialContext implements the Dialer interface.
|
||
|
// Hostnames for address are currently not supported.
|
||
|
func (d *dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||
|
addr, err := netip.ParseAddrPort(address)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
d.mtx.RLock()
|
||
|
defer d.mtx.RUnlock()
|
||
|
|
||
|
for i := range d.subnets {
|
||
|
if d.subnets[i].Mask.Contains(addr.Addr()) {
|
||
|
return d.balancer.DialContext(ctx, &d.subnets[i], network, address)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if d.restrict {
|
||
|
return nil, fmt.Errorf("no suitable interface for: [%s]%s", network, address)
|
||
|
}
|
||
|
return d.dialer.DialContext(ctx, network, address)
|
||
|
}
|
||
|
|
||
|
func (d *dialer) dialContext(nd *net.Dialer, ctx context.Context, network, address string) (net.Conn, error) {
|
||
|
if h := d.testHookDialContext; h != nil {
|
||
|
return h(nd, ctx, "tcp", address)
|
||
|
}
|
||
|
return nd.DialContext(ctx, "tcp", address)
|
||
|
}
|
||
|
|
||
|
// UpdateInterface implements the Multidialer interface.
|
||
|
// Updating address on a specific interface is currently not supported.
|
||
|
func (d *dialer) UpdateInterface(iface string, addr netip.Addr, up bool) error {
|
||
|
d.mtx.Lock()
|
||
|
defer d.mtx.Unlock()
|
||
|
|
||
|
for i := range d.subnets {
|
||
|
for j := range d.subnets[i].Interfaces {
|
||
|
matchIface := d.subnets[i].Interfaces[j].Name == iface
|
||
|
if matchIface {
|
||
|
d.subnets[i].Interfaces[j].Down = !up
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
a, _ := netip.AddrFromSlice(d.subnets[i].Interfaces[j].LocalAddr.IP)
|
||
|
matchAddr := a.IsUnspecified() || addr == a
|
||
|
if matchAddr {
|
||
|
d.subnets[i].Interfaces[j].Down = !up
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
return nil
|
||
|
}
|