multinet/balancer.go
2024-10-08 14:01:11 +03:00

64 lines
1.5 KiB
Go

package multinet
import (
"context"
"errors"
"fmt"
"net"
"sync/atomic"
)
// BalancerType represents the algorithm which is used to pick the source
// address for an outgoing connection.
type BalancerType string

const (
	// BalancerTypeNoop picks the first address for which the link is up.
	BalancerTypeNoop BalancerType = ""
	// BalancerTypeRoundRobin implements simple round-robin between up links.
	// It is not fair in case some links are down: the turn of a down link is
	// passed to the next link that is up, so that link is picked more often.
	BalancerTypeRoundRobin BalancerType = "roundrobin"
)
// errNoSuitableNodeFound is wrapped into the error returned by the balancers'
// DialContext when no interface of the subnet has its link up (including the
// case of a subnet with no interfaces at all).
var errNoSuitableNodeFound = errors.New("no suitable node found")
// balancer decides which local interface of a subnet an outgoing connection
// is dialed from.
type balancer interface {
	// DialContext connects to address on the named network, using a local
	// address chosen from the interfaces of s.
	DialContext(
		ctx context.Context,
		s *Subnet,
		network, address string,
	) (net.Conn, error)
}
// roundRobin is a balancer that shifts the starting interface by one on every
// dial and uses the first interface, at or after that position, whose link is
// up.
type roundRobin struct {
	d *dialer
	// i counts DialContext calls; it selects the starting interface.
	i atomic.Uint32
}

// DialContext dials address on network from the next interface in rotation
// whose link is up. Only one dial attempt is made: if it fails, its error is
// returned without trying the remaining interfaces. When no interface is up
// (or s has none), an error wrapping errNoSuitableNodeFound is returned.
func (r *roundRobin) DialContext(ctx context.Context, s *Subnet, network, address string) (net.Conn, error) {
	// The counter advances on every call, even when s has no interfaces,
	// so the rotation sequence is the same as before.
	counter := r.i.Add(1)
	if n := len(s.Interfaces); n > 0 {
		// Reduce modulo n in uint32 arithmetic before converting to int:
		// int(counter) is negative on 32-bit platforms once the counter
		// exceeds math.MaxInt32, which would produce a negative index and
		// panic. start is always in [0, n), so the conversion is safe.
		start := int(counter % uint32(n))
		for i := 0; i < n; i++ {
			ii := s.Interfaces[(start+i)%n]
			if ii.Down {
				continue
			}
			// Work on a copy of the base dialer so LocalAddr is set only for
			// this call (assumes r.d.dialer is a struct value, not a pointer).
			dd := r.d.dialer
			dd.LocalAddr = ii.LocalAddr
			return r.d.dialContext(&dd, ctx, network, address)
		}
	}
	return nil, fmt.Errorf("(*roundRobin).DialContext: %w", errNoSuitableNodeFound)
}
// firstEnabled is a balancer that always uses the first interface, in
// s.Interfaces order, whose link is up.
type firstEnabled struct {
	d *dialer
}

// DialContext dials address on network from the first interface of s whose
// link is up. Only one dial attempt is made: a failed dial is returned as-is
// without falling back to later interfaces. When no interface is up (or s has
// none), an error wrapping errNoSuitableNodeFound is returned.
func (r *firstEnabled) DialContext(ctx context.Context, s *Subnet, network, address string) (net.Conn, error) {
	for _, ii := range s.Interfaces {
		if ii.Down {
			continue
		}
		// Work on a copy of the base dialer so LocalAddr is set only for
		// this call (assumes r.d.dialer is a struct value, not a pointer).
		dd := r.d.dialer
		dd.LocalAddr = ii.LocalAddr
		return r.d.dialContext(&dd, ctx, network, address)
	}
	return nil, fmt.Errorf("(*firstEnabled).DialContext: %w", errNoSuitableNodeFound)
}