forked from TrueCloudLab/multinet
431 lines
11 KiB
Go
431 lines
11 KiB
Go
package multinet
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"fmt"
|
|
"net"
|
|
"net/netip"
|
|
"sort"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// defaultFallbackDelay is the RFC 6555 Fast Fallback delay used when
// Config.FallbackDelay is left at zero.
const defaultFallbackDelay = 300 * time.Millisecond
|
|
|
|
// Dialer contains the single most important method from the net.Dialer.
// Both *net.Dialer and this package's multi-interface dialer satisfy it,
// so either can be used wherever a Dialer is accepted.
type Dialer interface {
	// DialContext connects to address on the named network, honoring ctx.
	DialContext(ctx context.Context, network, address string) (net.Conn, error)
}
|
|
|
|
// Multidialer is like Dialer, but supports link state updates.
type Multidialer interface {
	Dialer
	// UpdateInterface records a link state change for the interface called
	// name with local address addr; status reports whether the link is up.
	UpdateInterface(name string, addr netip.Addr, status bool)
}
|
|
|
|
var (
|
|
_ Dialer = (*net.Dialer)(nil)
|
|
_ Dialer = (*dialer)(nil)
|
|
)
|
|
|
|
// dialer is the Multidialer implementation returned by NewDialer.
type dialer struct {
	// Protects subnets field (recursively), including the Down flag of
	// every Source: UpdateInterface writes it under Lock, dialAddr reads
	// it under RLock.
	mtx sync.RWMutex
	// Configured subnets in Config.Subnets order; dialAddr uses the first
	// one whose prefix contains the destination address.
	subnets []Subnet
	// Default options for the net.Dialer (copied from Config.Dialer).
	dialer net.Dialer
	// If true, only destinations inside a configured subnet may be dialed.
	restrict bool
	// Algorithm which picks the source address for each destination.
	balancer balancer
	// Overrides `net.Dialer.DialContext()` if specified.
	customDialContext func(d *net.Dialer, ctx context.Context, network, address string) (net.Conn, error)
	// Hostname resolver; NewDialer points its Dial hook at dialContextIP.
	resolver net.Resolver
	// See Config.FallbackDelay description. NewDialer replaces zero with
	// defaultFallbackDelay; a negative value disables Fast Fallback.
	fallbackDelay time.Duration
}
|
|
|
|
// Subnet represents a single subnet, possibly routable from multiple interfaces.
type Subnet struct {
	// Mask is the configured CIDR prefix.
	Mask netip.Prefix
	// Interfaces lists every local source address inside Mask, sorted by
	// interface name and then by IP bytes.
	Interfaces []Source
}
|
|
|
|
// Source represents a single source IP belonging to a particular subnet.
type Source struct {
	// Name is the name of the interface owning the address.
	Name string
	// LocalAddr holds the source IP; NewDialer leaves the port unset.
	LocalAddr *net.TCPAddr
	// Down is true while the owning link is considered down; it is kept
	// current by UpdateInterface.
	Down bool
}
|
|
|
|
// Config contains Multidialer configuration.
type Config struct {
	// Routable subnets to prioritize, in CIDR format (e.g. "10.0.0.0/8").
	Subnets []string
	// If true, only the configured subnets are reachable through this
	// dialer; any other destination fails with an error.
	// Otherwise, other destinations are dialed with Dialer (or DialContext,
	// when set) without any source-address selection.
	Restrict bool
	// Dialer contains the default options for the net.Dialer to use.
	// Its LocalAddr field is overridden.
	Dialer net.Dialer
	// Balancer specifies the algorithm used to pick a source address.
	Balancer BalancerType
	// FallbackDelay specifies the length of time to wait before
	// spawning a RFC 6555 Fast Fallback connection. That is, this
	// is the amount of time to wait for IPv6 to succeed before
	// assuming that IPv6 is misconfigured and falling back to
	// IPv4.
	//
	// If zero, a default delay of 300ms is used.
	// A negative value disables Fast Fallback support.
	FallbackDelay time.Duration
	// InterfaceSource is a custom source of `Interface` values.
	// If not specified, the package's default implementation is used
	// (presumably backed by `net.Interfaces()`).
	InterfaceSource func() ([]Interface, error)
	// DialContext is a custom DialContext function.
	// If not specified, the default implementation is used (`d.DialContext(ctx, network, address)`).
	DialContext func(d *net.Dialer, ctx context.Context, network, address string) (net.Conn, error)
}
|
|
|
|
// NewDialer ...
|
|
func NewDialer(c Config) (Multidialer, error) {
|
|
var ifaces []Interface
|
|
var err error
|
|
if c.InterfaceSource != nil {
|
|
ifaces, err = c.InterfaceSource()
|
|
} else {
|
|
ifaces, err = systemInterfaces()
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
sort.Slice(ifaces, func(i, j int) bool {
|
|
return ifaces[i].Name() < ifaces[j].Name()
|
|
})
|
|
|
|
var sources []iface
|
|
for i := range ifaces {
|
|
info, err := processIface(ifaces[i])
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
sources = append(sources, info)
|
|
}
|
|
|
|
var d dialer
|
|
for _, subnet := range c.Subnets {
|
|
s, err := processSubnet(subnet, sources)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
d.subnets = append(d.subnets, s)
|
|
}
|
|
|
|
switch c.Balancer {
|
|
case BalancerTypeNoop:
|
|
d.balancer = &firstEnabled{d: &d}
|
|
case BalancerTypeRoundRobin:
|
|
d.balancer = &roundRobin{d: &d}
|
|
default:
|
|
return nil, fmt.Errorf("invalid balancer type: %s", c.Balancer)
|
|
}
|
|
|
|
d.restrict = c.Restrict
|
|
d.resolver.Dial = d.dialContextIP
|
|
d.fallbackDelay = c.FallbackDelay
|
|
if d.fallbackDelay == 0 {
|
|
d.fallbackDelay = defaultFallbackDelay
|
|
}
|
|
d.dialer = c.Dialer
|
|
if c.DialContext != nil {
|
|
d.customDialContext = c.DialContext
|
|
}
|
|
|
|
return &d, nil
|
|
}
|
|
|
|
// iface is a snapshot of one Interface taken by processIface: its name,
// its addresses as prefixes, and whether it was down at that moment.
type iface struct {
	name  string
	addrs []netip.Prefix
	down  bool
}
|
|
|
|
func processIface(info Interface) (iface, error) {
|
|
ips, err := info.Addrs()
|
|
if err != nil {
|
|
return iface{}, err
|
|
}
|
|
|
|
var addrs []netip.Prefix
|
|
for i := range ips {
|
|
p, err := netip.ParsePrefix(ips[i].String())
|
|
if err != nil {
|
|
return iface{}, err
|
|
}
|
|
|
|
addrs = append(addrs, p)
|
|
}
|
|
return iface{name: info.Name(), addrs: addrs, down: info.Down()}, nil
|
|
}
|
|
|
|
func processSubnet(subnet string, sources []iface) (Subnet, error) {
|
|
s, err := netip.ParsePrefix(subnet)
|
|
if err != nil {
|
|
return Subnet{}, err
|
|
}
|
|
|
|
var ifs []Source
|
|
for _, source := range sources {
|
|
for i := range source.addrs {
|
|
src := source.addrs[i].Addr()
|
|
if s.Contains(src) {
|
|
ifs = append(ifs, Source{
|
|
Name: source.name,
|
|
LocalAddr: &net.TCPAddr{IP: net.IP(src.AsSlice())},
|
|
Down: source.down,
|
|
})
|
|
}
|
|
}
|
|
}
|
|
|
|
sort.Slice(ifs, func(i, j int) bool {
|
|
if ifs[i].Name != ifs[j].Name {
|
|
return ifs[i].Name < ifs[j].Name
|
|
}
|
|
return bytes.Compare(ifs[i].LocalAddr.IP, ifs[j].LocalAddr.IP) == -1
|
|
})
|
|
|
|
return Subnet{
|
|
Mask: s,
|
|
Interfaces: ifs,
|
|
}, nil
|
|
}
|
|
|
|
// DialContext implements the Dialer interface.
|
|
// Hostnames for address are currently not supported.
|
|
func (d *dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
|
addr, err := netip.ParseAddrPort(address)
|
|
if err != nil { //try resolve as hostname
|
|
return d.dialContextHostname(ctx, network, address)
|
|
}
|
|
return d.dialAddr(ctx, network, address, addr)
|
|
}
|
|
|
|
func (d *dialer) dialContextIP(ctx context.Context, network, address string) (net.Conn, error) {
|
|
addr, err := netip.ParseAddrPort(address)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return d.dialAddr(ctx, network, address, addr)
|
|
}
|
|
|
|
func (d *dialer) dialContextHostname(ctx context.Context, network, address string) (net.Conn, error) {
|
|
// https://github.com/golang/go/blob/release-branch.go1.21/src/net/dial.go#L488
|
|
addrPorts, err := d.resolveHostIPs(ctx, address, network)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var primaries, fallbacks []netip.AddrPort
|
|
if d.fallbackDelay >= 0 && network == "tcp" {
|
|
primaries, fallbacks = splitByType(addrPorts)
|
|
} else {
|
|
primaries = addrPorts
|
|
}
|
|
|
|
return d.dialParallel(ctx, network, primaries, fallbacks)
|
|
}
|
|
|
|
func (d *dialer) dialSerial(ctx context.Context, network string, addrs []netip.AddrPort) (net.Conn, error) {
|
|
var firstErr error
|
|
for _, addr := range addrs {
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
default:
|
|
}
|
|
|
|
conn, err := d.dialAddr(ctx, network, addr.String(), addr)
|
|
if err == nil {
|
|
return conn, nil
|
|
}
|
|
if firstErr == nil {
|
|
firstErr = err
|
|
}
|
|
}
|
|
return nil, fmt.Errorf("failed to connect any resolved address: %w", firstErr)
|
|
}
|
|
|
|
// dialParallel races two serial dials: one over primaries, started at once,
// and one over fallbacks, started after d.fallbackDelay (or immediately once
// the primary racer fails). The first successful connection is returned.
// If both racers fail, the primary racer's error is returned.
// With no fallbacks it degenerates to a single dialSerial over primaries.
func (d *dialer) dialParallel(ctx context.Context, network string, primaries, fallbacks []netip.AddrPort) (net.Conn, error) {
	if len(fallbacks) == 0 {
		return d.dialSerial(ctx, network, primaries)
	}

	// Closed when this function returns; a racer that finishes after that
	// point sees it and closes its now-unwanted connection.
	returned := make(chan struct{})
	defer close(returned)

	type dialResult struct {
		net.Conn
		error
		primary bool
		done    bool
	}
	// Unbuffered: a racer's send only succeeds while we are still receiving.
	results := make(chan dialResult)

	// dialerFunc runs one racer: a serial dial over primaries or fallbacks,
	// then either reports the outcome or, if we already returned, closes
	// any connection it obtained.
	dialerFunc := func(ctx context.Context, primary bool) {
		addrs := primaries
		if !primary {
			addrs = fallbacks
		}
		c, err := d.dialSerial(ctx, network, addrs)
		select {
		case results <- dialResult{Conn: c, error: err, primary: primary, done: true}:
		case <-returned:
			if c != nil {
				c.Close()
			}
		}
	}

	var primary, fallback dialResult

	// Start the primary racer; its context is cancelled when we return.
	primaryCtx, primaryCancel := context.WithCancel(ctx)
	defer primaryCancel()
	go dialerFunc(primaryCtx, true)

	// Start the timer for the fallback racer.
	fallbackTimer := time.NewTimer(d.fallbackDelay)
	defer fallbackTimer.Stop()

	for {
		select {
		case <-fallbackTimer.C:
			// The timer fires at most once (it is only Reset after a
			// successful Stop), so this defer inside the loop runs once.
			fallbackCtx, fallbackCancel := context.WithCancel(ctx)
			defer fallbackCancel()
			go dialerFunc(fallbackCtx, false)

		case res := <-results:
			if res.error == nil {
				return res.Conn, nil
			}
			if res.primary {
				primary = res
			} else {
				fallback = res
			}
			if primary.done && fallback.done {
				return nil, primary.error
			}
			if res.primary && fallbackTimer.Stop() {
				// If we were able to stop the timer, that means it
				// was running (hadn't yet started the fallback), but
				// we just got an error on the primary path, so start
				// the fallback immediately (in 0 nanoseconds).
				fallbackTimer.Reset(0)
			}
		}
	}
}
|
|
|
|
func (d *dialer) resolveHostIPs(ctx context.Context, address string, network string) ([]netip.AddrPort, error) {
|
|
host, port, err := net.SplitHostPort(address)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid address format: %w", err)
|
|
}
|
|
|
|
portnum, err := d.resolver.LookupPort(ctx, network, port)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid port format: %w", err)
|
|
}
|
|
|
|
ips, err := d.resolver.LookupHost(ctx, host)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to resolve host address: %w", err)
|
|
}
|
|
|
|
if len(ips) == 0 {
|
|
return nil, fmt.Errorf("failed to resolve address for [%s]%s", network, address)
|
|
}
|
|
|
|
var ipAddrs []netip.Addr
|
|
for _, ip := range ips {
|
|
ipAddr, err := netip.ParseAddr(ip)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to parse ip address '%s': %w", ip, err)
|
|
}
|
|
ipAddrs = append(ipAddrs, ipAddr)
|
|
}
|
|
|
|
var addrPorts []netip.AddrPort
|
|
for _, ipAddr := range ipAddrs {
|
|
addrPorts = append(addrPorts, netip.AddrPortFrom(ipAddr, uint16(portnum)))
|
|
}
|
|
return addrPorts, nil
|
|
}
|
|
|
|
func (d *dialer) dialAddr(ctx context.Context, network, address string, addr netip.AddrPort) (net.Conn, error) {
|
|
d.mtx.RLock()
|
|
defer d.mtx.RUnlock()
|
|
|
|
for i := range d.subnets {
|
|
if d.subnets[i].Mask.Contains(addr.Addr()) {
|
|
return d.balancer.DialContext(ctx, &d.subnets[i], network, address)
|
|
}
|
|
}
|
|
|
|
if d.restrict {
|
|
return nil, fmt.Errorf("no suitable interface for: [%s]%s", network, address)
|
|
}
|
|
return d.dialContext(&d.dialer, ctx, network, address)
|
|
}
|
|
|
|
func (d *dialer) dialContext(nd *net.Dialer, ctx context.Context, network, address string) (net.Conn, error) {
|
|
if h := d.customDialContext; h != nil {
|
|
return h(nd, ctx, network, address)
|
|
}
|
|
return nd.DialContext(ctx, network, address)
|
|
}
|
|
|
|
// UpdateInterface implements the Multidialer interface.
|
|
// Updating address on a specific interface is currently not supported.
|
|
func (d *dialer) UpdateInterface(iface string, addr netip.Addr, up bool) {
|
|
d.mtx.Lock()
|
|
defer d.mtx.Unlock()
|
|
|
|
for i := range d.subnets {
|
|
for j := range d.subnets[i].Interfaces {
|
|
matchIface := d.subnets[i].Interfaces[j].Name == iface
|
|
if matchIface {
|
|
d.subnets[i].Interfaces[j].Down = !up
|
|
continue
|
|
}
|
|
|
|
a, _ := netip.AddrFromSlice(d.subnets[i].Interfaces[j].LocalAddr.IP)
|
|
matchAddr := a.IsUnspecified() || addr == a
|
|
if matchAddr {
|
|
d.subnets[i].Interfaces[j].Down = !up
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// splitByType divides an address list by IP family: the first address and
// every address of the same family (IPv4 vs. non-IPv4) become primaries, and
// all others become fallbacks. Relative order is preserved in both lists.
// An empty input yields two nil slices.
func splitByType(addrs []netip.AddrPort) (primaries []netip.AddrPort, fallbacks []netip.AddrPort) {
	if len(addrs) == 0 {
		return nil, nil
	}

	firstIs4 := addrs[0].Addr().Is4()
	for _, ap := range addrs {
		if ap.Addr().Is4() == firstIs4 {
			primaries = append(primaries, ap)
		} else {
			fallbacks = append(fallbacks, ap)
		}
	}
	return primaries, fallbacks
}
|