multinet/dialer.go

328 lines
8.6 KiB
Go
Raw Normal View History

package multinet
import (
"context"
"fmt"
"net"
"net/netip"
"sync"
"time"
)
const (
// defaultFallbackDelay is the Fast Fallback delay that NewDialer applies
// when Config.FallbackDelay is left at zero.
defaultFallbackDelay = 300 * time.Millisecond
)
// Dialer contains the single most important method from the net.Dialer.
// Both *net.Dialer and the value returned by NewDialer satisfy it.
type Dialer interface {
// DialContext has the same signature as (*net.Dialer).DialContext.
DialContext(ctx context.Context, network, address string) (net.Conn, error)
}
// Compile-time checks that *net.Dialer and *dialer implement Dialer.
var (
_ Dialer = (*net.Dialer)(nil)
_ Dialer = (*dialer)(nil)
)
// dialer is the Dialer implementation constructed by NewDialer.
type dialer struct {
// Protects subnets field (recursively).
// dialAddr holds the read lock for the whole duration of a dial.
mtx sync.RWMutex
// Configured subnets. dialAddr scans them in order and uses the first one
// whose Prefix contains the destination address.
subnets []Subnet
// Default options for the net.Dialer.
// Used as-is for destinations outside every configured subnet when
// restrict is false.
dialer net.Dialer
// If true, allow to dial only configured subnets.
restrict bool
// Algorithm that picks the source address for each IP.
balancer balancer
// Overrides `net.Dialer.DialContext()` if specified.
customDialContext func(d *net.Dialer, ctx context.Context, network, address string) (net.Conn, error)
// Hostname resolver. NewDialer points its Dial hook at dialContextIP.
resolver net.Resolver
// See Config.FallbackDelay description.
fallbackDelay time.Duration
// Event handler. NewDialer never leaves it nil: a no-op handler is
// installed when the Config does not provide one.
eh EventHandler
}
// Subnet represents a single subnet, possibly routable from multiple source IPs.
type Subnet struct {
// Prefix is the destination range. A destination belongs to this subnet
// when Prefix.Contains reports true for its address.
Prefix netip.Prefix
// SourceIPs lists the source addresses this subnet is routable from.
// NOTE(review): presumably the balancer chooses among these; its
// implementation is not part of this file — confirm there.
SourceIPs []netip.Addr
}
// EventHandler receives notifications about dial attempts made by the dialer.
type EventHandler interface {
// DialPerformed is called after a dial attempt has finished. Within this
// file it is invoked with the LocalAddr of the net.Dialer that was used as
// sourceIP, the network and address that were dialed, and the resulting
// error (nil on success).
DialPerformed(sourceIP net.Addr, network, address string, err error)
}
// Config contains Multidialer configuration.
type Config struct {
// Routable subnets.
Subnets []Subnet
// If true, only the configured subnets are available through this dialer.
// Otherwise, destinations outside the configured subnets are dialed with
// the options from Dialer.
Restrict bool
// Dialer contains default options for the net.Dialer to use.
// LocalAddr is overridden.
Dialer net.Dialer
// Balancer specifies algorithm used to pick source address.
// NewDialer rejects values other than BalancerTypeNoop and
// BalancerTypeRoundRobin.
Balancer BalancerType
// FallbackDelay specifies the length of time to wait before
// spawning a RFC 6555 Fast Fallback connection. That is, this
// is the amount of time to wait for IPv6 to succeed before
// assuming that IPv6 is misconfigured and falling back to
// IPv4.
//
// If zero, a default delay of 300ms is used.
// A negative value disables Fast Fallback support.
FallbackDelay time.Duration
// DialContext is custom DialContext function.
// If not specified, default implementation is used (`d.DialContext(ctx, network, address)`).
DialContext func(d *net.Dialer, ctx context.Context, network, address string) (net.Conn, error)
// EventHandler defines event handler.
// If nil, events are discarded.
EventHandler EventHandler
}
// NewDialer builds a Dialer from the given Config.
//
// It returns an error when c.Balancer is not a known balancer type. A zero
// c.FallbackDelay is replaced by the 300ms default, and a nil c.EventHandler
// is replaced by a handler that ignores every event.
func NewDialer(c Config) (Dialer, error) {
	d := &dialer{
		subnets:           c.Subnets,
		restrict:          c.Restrict,
		dialer:            c.Dialer,
		fallbackDelay:     c.FallbackDelay,
		customDialContext: c.DialContext,
		eh:                c.EventHandler,
	}

	// The balancer keeps a back-reference to the dialer it serves.
	switch c.Balancer {
	case BalancerTypeRoundRobin:
		d.balancer = &roundRobin{d: d}
	case BalancerTypeNoop:
		d.balancer = &firstEnabled{d: d}
	default:
		return nil, fmt.Errorf("invalid balancer type: %s", c.Balancer)
	}

	// Connections opened by the resolver go through this dialer as well.
	d.resolver.Dial = d.dialContextIP

	if d.fallbackDelay == 0 {
		d.fallbackDelay = defaultFallbackDelay
	}
	if d.eh == nil {
		d.eh = noopEventHandler{}
	}
	return d, nil
}
// DialContext implements the Dialer interface.
//
// When address is a literal "ip:port" it is dialed directly. Any address
// that does not parse as such is treated as "host:port" and resolved first.
func (d *dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
	if ipPort, parseErr := netip.ParseAddrPort(address); parseErr == nil {
		return d.dialAddr(ctx, network, address, ipPort)
	}
	// Not an IP literal: fall back to hostname resolution.
	return d.dialContextHostname(ctx, network, address)
}
// dialContextIP dials address, which must be a literal "ip:port". Unlike
// DialContext it never resolves hostnames: a parse failure is returned as is.
// NewDialer installs it as the resolver's Dial hook.
func (d *dialer) dialContextIP(ctx context.Context, network, address string) (net.Conn, error) {
	ipPort, parseErr := netip.ParseAddrPort(address)
	if parseErr == nil {
		return d.dialAddr(ctx, network, address, ipPort)
	}
	return nil, parseErr
}
// dialContextHostname resolves the host part of address and dials the
// resulting addresses.
//
// For network "tcp" with a non-negative fallbackDelay the resolved addresses
// are split by address family so that dialParallel can apply Fast Fallback.
// In every other case all addresses are tried as primaries.
//
// Modelled after
// https://github.com/golang/go/blob/release-branch.go1.21/src/net/dial.go#L488
func (d *dialer) dialContextHostname(ctx context.Context, network, address string) (net.Conn, error) {
	resolved, resolveErr := d.resolveHostIPs(ctx, address, network)
	if resolveErr != nil {
		return nil, resolveErr
	}

	// Fast Fallback is disabled or not applicable: no fallback group.
	if d.fallbackDelay < 0 || network != "tcp" {
		return d.dialParallel(ctx, network, resolved, nil)
	}

	first, second := splitByType(resolved)
	return d.dialParallel(ctx, network, first, second)
}
// dialSerial tries the addresses in addrs one after another and returns the
// first connection that is established.
//
// Before each attempt the context is checked; if it is already done,
// ctx.Err() is returned immediately. When every attempt fails, the error of
// the first failed attempt is wrapped and returned. An empty addrs yields a
// dedicated error instead of wrapping a nil error.
func (d *dialer) dialSerial(ctx context.Context, network string, addrs []netip.AddrPort) (net.Conn, error) {
	var firstErr error
	for _, addr := range addrs {
		// Stop early instead of starting another attempt on a dead context.
		select {
		case <-ctx.Done():
			return nil, ctx.Err()
		default:
		}

		conn, err := d.dialAddr(ctx, network, addr.String(), addr)
		if err == nil {
			return conn, nil
		}
		if firstErr == nil {
			firstErr = err
		}
	}

	if firstErr == nil {
		// Nothing was attempted (addrs is empty). Wrapping a nil error with
		// %w would render as "%!w(<nil>)", so report the cause explicitly.
		return nil, fmt.Errorf("failed to connect any resolved address: no addresses to dial")
	}
	return nil, fmt.Errorf("failed to connect any resolved address: %w", firstErr)
}
// dialParallel races two serial dials in the manner of the RFC 6555 Fast
// Fallback scheme described on Config.FallbackDelay. The primaries are dialed
// immediately; the fallbacks are started once d.fallbackDelay has elapsed, or
// right away if the primary attempt fails while the timer is still pending.
//
// The first connection established by either attempt is returned. If both
// attempts fail, the error of the primary attempt is returned. With no
// fallbacks the call reduces to dialSerial over primaries.
func (d *dialer) dialParallel(ctx context.Context, network string, primaries, fallbacks []netip.AddrPort) (net.Conn, error) {
if len(fallbacks) == 0 {
return d.dialSerial(ctx, network, primaries)
}
// returned is closed when this function exits. A dial goroutine that
// finishes afterwards uses it to learn that nobody will read its result.
returned := make(chan struct{})
defer close(returned)
// dialResult is the outcome of one dialSerial run.
type dialResult struct {
net.Conn
error
// primary tells which of the two attempts produced the result.
primary bool
// done is true for every delivered result; it distinguishes a received
// result from the zero value of the variables declared below.
done bool
}
// Unbuffered: a send only succeeds while the loop below is receiving.
results := make(chan dialResult)
dialerFunc := func(ctx context.Context, primary bool) {
addrs := primaries
if !primary {
addrs = fallbacks
}
c, err := d.dialSerial(ctx, network, addrs)
select {
case results <- dialResult{Conn: c, error: err, primary: primary, done: true}:
case <-returned:
// dialParallel has already returned, so the connection (if any) is
// unwanted; close it rather than leak it.
if c != nil {
c.Close()
}
}
}
var primary, fallback dialResult
// Each attempt gets its own cancelable context; the deferred cancels fire
// when dialParallel returns and stop whichever attempt is still running.
primaryCtx, primaryCancel := context.WithCancel(ctx)
defer primaryCancel()
go dialerFunc(primaryCtx, true)
fallbackTimer := time.NewTimer(d.fallbackDelay)
defer fallbackTimer.Stop()
for {
select {
case <-fallbackTimer.C:
// Start the fallback attempt. The deferred cancel runs at function
// return, not at the end of this loop iteration.
fallbackCtx, fallbackCancel := context.WithCancel(ctx)
defer fallbackCancel()
go dialerFunc(fallbackCtx, false)
case res := <-results:
// The first successful connection wins.
if res.error == nil {
return res.Conn, nil
}
if res.primary {
primary = res
} else {
fallback = res
}
// Both attempts have failed: report the primary attempt's error.
if primary.done && fallback.done {
return nil, primary.error
}
if res.primary && fallbackTimer.Stop() {
// If we were able to stop the timer, that means it
// was running (hadn't yet started the fallback), but
// we just got an error on the primary path, so start
// the fallback immediately (in 0 nanoseconds).
fallbackTimer.Reset(0)
}
}
}
}
// resolveHostIPs turns a "host:port" address into a list of IP/port pairs.
//
// The port (numeric or a service name) is resolved for the given network and
// the host is resolved through d.resolver. The result holds one entry per
// resolved IP, in the order the resolver returned them. An error is returned
// if the address cannot be split, the port or host cannot be resolved, the
// resolver returns no addresses, or a returned address is not a valid IP.
func (d *dialer) resolveHostIPs(ctx context.Context, address string, network string) ([]netip.AddrPort, error) {
	host, service, splitErr := net.SplitHostPort(address)
	if splitErr != nil {
		return nil, fmt.Errorf("invalid address format: %w", splitErr)
	}

	port, portErr := d.resolver.LookupPort(ctx, network, service)
	if portErr != nil {
		return nil, fmt.Errorf("invalid port format: %w", portErr)
	}

	hostIPs, lookupErr := d.resolver.LookupHost(ctx, host)
	switch {
	case lookupErr != nil:
		return nil, fmt.Errorf("failed to resolve host address: %w", lookupErr)
	case len(hostIPs) == 0:
		return nil, fmt.Errorf("failed to resolve address for [%s]%s", network, address)
	}

	// Parse and pair with the port in a single pass.
	resolved := make([]netip.AddrPort, 0, len(hostIPs))
	for _, raw := range hostIPs {
		ip, parseErr := netip.ParseAddr(raw)
		if parseErr != nil {
			return nil, fmt.Errorf("failed to parse ip address '%s': %w", raw, parseErr)
		}
		resolved = append(resolved, netip.AddrPortFrom(ip, uint16(port)))
	}
	return resolved, nil
}
// dialAddr dials a single, already resolved destination.
//
// The first configured subnet whose prefix contains addr handles the dial
// through the balancer. A destination outside every subnet is rejected when
// the dialer is restricted, and dialed with the default net.Dialer options
// otherwise. The read lock on the subnets is held for the whole call,
// including the dial itself.
func (d *dialer) dialAddr(ctx context.Context, network, address string, addr netip.AddrPort) (net.Conn, error) {
	d.mtx.RLock()
	defer d.mtx.RUnlock()

	target := addr.Addr()
	for i := range d.subnets {
		subnet := &d.subnets[i]
		if !subnet.Prefix.Contains(target) {
			continue
		}
		return d.balancer.DialContext(ctx, subnet, network, address)
	}

	if !d.restrict {
		return d.dialContext(ctx, &d.dialer, network, address)
	}
	return nil, fmt.Errorf("no suitable interface for: [%s]%s", network, address)
}
// dialContext performs one dial with the given net.Dialer and reports the
// outcome to the event handler.
//
// The custom DialContext function from the Config is used when one was
// provided; otherwise nd.DialContext is called. The event handler receives
// nd.LocalAddr as the source address and the dial error (nil on success).
func (d *dialer) dialContext(ctx context.Context, nd *net.Dialer, network, address string) (net.Conn, error) {
	var (
		conn    net.Conn
		dialErr error
	)
	if d.customDialContext == nil {
		conn, dialErr = nd.DialContext(ctx, network, address)
	} else {
		conn, dialErr = d.customDialContext(nd, ctx, network, address)
	}

	d.eh.DialPerformed(nd.LocalAddr, network, address, dialErr)
	return conn, dialErr
}
// splitByType partitions addrs by address family while keeping their
// relative order. The family of the first address (IPv4 or not) defines the
// primaries; every address of the other family goes to fallbacks. Both
// results are nil for an empty input.
func splitByType(addrs []netip.AddrPort) (primaries []netip.AddrPort, fallbacks []netip.AddrPort) {
	if len(addrs) == 0 {
		return nil, nil
	}

	// The first address decides which family is the primary one.
	primaryIsV4 := addrs[0].Addr().Is4()
	for _, ap := range addrs {
		if ap.Addr().Is4() == primaryIsV4 {
			primaries = append(primaries, ap)
			continue
		}
		fallbacks = append(fallbacks, ap)
	}
	return primaries, fallbacks
}
// noopEventHandler is the EventHandler installed when the Config does not
// provide one; it discards every event.
type noopEventHandler struct{}

// DialPerformed implements EventHandler and intentionally does nothing.
func (noopEventHandler) DialPerformed(_ net.Addr, _, _ string, _ error) {}